Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
67623aa
wip
joerunde Oct 31, 2025
3ce3a48
:zap: functional SpinCondition
joerunde Nov 4, 2025
21c7b04
:art: cleanup
joerunde Nov 4, 2025
b296ad5
:art: fmt
joerunde Nov 4, 2025
0b84082
:art: cleanup
joerunde Nov 4, 2025
6a36b18
Merge branch 'main' into new-poll-fix
joerunde Nov 4, 2025
8037303
Merge branch 'main' into new-poll-fix
joerunde Nov 13, 2025
00c4c3f
:poop: WIP unit tests
joerunde Nov 13, 2025
f55c68e
test: flesh out shm_broadcast tests
tjohnson31415 Nov 17, 2025
0bd12b3
fix timeout handling and little refactor
tjohnson31415 Nov 17, 2025
74cc6d5
Merge branch 'main' into new-poll-fix
tjohnson31415 Nov 18, 2025
9c97af5
Merge branch 'main' into new-poll-fix
tjohnson31415 Nov 20, 2025
73c3398
test: add busy shutdown test
tjohnson31415 Dec 1, 2025
184db2a
Merge branch 'main' into new-poll-fix
joerunde Dec 4, 2025
28a4e60
Merge branch 'main' into new-poll-fix
tjohnson31415 Dec 18, 2025
3e47bf4
Merge branch 'main' into new-poll-fix
joerunde Jan 21, 2026
7749b21
:bug: fix uninitialized spin condition
joerunde Jan 21, 2026
9ff2361
:rewind: revert changes from #32965
joerunde Jan 23, 2026
0f0ccf6
:recycle: refactor timeout stuff
joerunde Feb 18, 2026
dbac53a
refactor: make ReadTimeout class clearer
tjohnson31415 Feb 18, 2026
de3f4a6
rename: ReadTimeoutWithWarnings
tjohnson31415 Feb 18, 2026
b8b56f0
:test_tube: add negative test for warning logs
joerunde Feb 18, 2026
1affbef
:bug: fix test hangs
joerunde Feb 19, 2026
40051c5
test: distributed_run fail fast
tjohnson31415 Feb 19, 2026
18febb3
refactor: move monitor_parent_death to be a worker method
tjohnson31415 Feb 19, 2026
6038f89
refactor: cleanup new monitor function
tjohnson31415 Feb 19, 2026
df8dcfe
Merge branch 'main' into new-poll-fix
joerunde Feb 20, 2026
5f14af2
review: changes from review
tjohnson31415 Feb 24, 2026
2f3f98c
fix: handle inherited socket connections when forking
tjohnson31415 Feb 24, 2026
c45326b
log: add some debug logs to ensure_worker_termination
tjohnson31415 Feb 25, 2026
e1565e0
fix: shutdown all queues in MultiprocExecutor
tjohnson31415 Feb 25, 2026
54ff00c
Merge branch 'main' into new-poll-fix
tjohnson31415 Feb 25, 2026
f7e3486
log: add logging to SpinCondition wait
tjohnson31415 Mar 2, 2026
1f04a71
Merge branch 'main' into new-poll-fix
tjohnson31415 Mar 2, 2026
26ee621
Merge branch 'main' into new-poll-fix
tjohnson31415 Mar 2, 2026
9be917a
Merge branch 'main' into new-poll-fix
tjohnson31415 Mar 3, 2026
cd06f1f
Merge branch 'main' into new-poll-fix
tjohnson31415 Mar 3, 2026
765660f
cleanup removed env var
njhill Mar 3, 2026
97eb604
minor code simplification
njhill Mar 3, 2026
364ce6c
Merge branch 'main' into new-poll-fix
tjohnson31415 Mar 3, 2026
8bf1da6
fix: time in log message now rounds to 0
tjohnson31415 Mar 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ def test_models(
[
("facebook/opt-125m", "ray", "", "L4", {}),
("facebook/opt-125m", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "A100", {}),
Expand Down
293 changes: 285 additions & 8 deletions tests/distributed/test_shm_broadcast.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import multiprocessing
import random
import threading
import time
from unittest import mock

import multiprocess as mp
import numpy as np
import pytest
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
Expand All @@ -22,7 +25,14 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
return [np.random.randint(1, 100, i) for i in sizes]


def distributed_run(fn, world_size):
def distributed_run(fn, world_size, timeout=60):
"""Run a function in multiple processes with proper error handling.

Args:
fn: Function to run in each process
world_size: Number of processes to spawn
timeout: Maximum time in seconds to wait for processes (default: 60)
"""
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
Expand All @@ -33,19 +43,45 @@ def distributed_run(fn, world_size):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env,))
p = mp.Process(target=fn, args=(env,))
processes.append(p)
p.start()

for p in processes:
p.join()
# Monitor processes and fail fast if any process fails
start_time = time.time()
failed_processes = []

# Wait for all processes, checking for failures
while time.time() - start_time < timeout:
all_done = True
for i, p in enumerate(processes):
if p.is_alive():
all_done = False
elif p.exitcode != 0:
# Process failed
failed_processes.append((i, p.exitcode))
break

if failed_processes or all_done:
break
time.sleep(0.1) # Check every 100ms

for p in processes:
assert p.exitcode == 0
# Check for timeout if no failures detected yet
for i, p in enumerate(processes):
if p.is_alive():
p.kill()
p.join()

# Report failures
if failed_processes:
error_msg = "Distributed test failed:\n"
for rank, status in failed_processes:
error_msg += f" Rank {rank}: Exit code {status}\n"
raise AssertionError(error_msg)


def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# `mp.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapped_fn(env):
Expand Down Expand Up @@ -115,3 +151,244 @@ def worker_fn():

def test_shm_broadcast():
    """Smoke test: run the basic broadcast worker across four ranks."""
    world_size = 4
    distributed_run(worker_fn, world_size)


@worker_fn_wrapper
def worker_fn_test_shutdown_busy():
    """Verify that shutdown() cancels a reader stuck in busy-spin mode.

    A background thread calls shutdown() on signal; the blocked dequeue
    must then raise RuntimeError instead of timing out.
    """
    my_rank = dist.get_rank()
    writer_rank = 2
    mq = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    if not mq._is_writer:
        # Force the reader's spin condition into busy-loop mode.
        mq._spin_condition.busy_loop_s = 9999

    fire_shutdown = threading.Event()

    def _shutdown_when_signaled():
        fire_shutdown.wait()
        mq.shutdown()

    threading.Thread(target=_shutdown_when_signaled).start()

    # Nothing was enqueued, so a short dequeue must time out first.
    with pytest.raises(TimeoutError):
        mq.dequeue(timeout=0.01)

    fire_shutdown.set()

    # Once shutdown runs, a blocked dequeue is cancelled, not timed out.
    with pytest.raises(RuntimeError, match="cancelled"):
        mq.dequeue(timeout=1)

    assert mq.shutting_down

    print(f"torch distributed passed the test! Rank {my_rank}")
    dist.barrier()


def test_message_queue_shutdown_busy(caplog_vllm):
    """Run the busy-mode shutdown worker on four ranks; dump captured logs."""
    world_size = 4
    distributed_run(worker_fn_test_shutdown_busy, world_size)
    print(caplog_vllm.text)


@worker_fn_wrapper
def worker_fn_test_shutdown_idle():
    """Verify that shutdown() wakes and cancels a reader in idle-wait mode.

    Mirrors the busy-mode shutdown test, but drives the reader into the
    idle (event-wait) path instead of the busy spin.
    """
    my_rank = dist.get_rank()
    writer_rank = 2
    mq = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    if not mq._is_writer:
        # Force the reader's spin condition into idle mode.
        mq._spin_condition.last_read = 0

    fire_shutdown = threading.Event()

    def _shutdown_when_signaled():
        fire_shutdown.wait()
        mq.shutdown()

    threading.Thread(target=_shutdown_when_signaled).start()

    # Nothing was enqueued, so a short dequeue must time out first.
    with pytest.raises(TimeoutError):
        mq.dequeue(timeout=0.01)

    fire_shutdown.set()

    # Once shutdown runs, a blocked dequeue is cancelled, not timed out.
    with pytest.raises(RuntimeError, match="cancelled"):
        mq.dequeue(timeout=1)

    assert mq.shutting_down

    print(f"torch distributed passed the test! Rank {my_rank}")
    dist.barrier()


def test_message_queue_shutdown_idle():
    """Run the idle-mode shutdown worker on four ranks."""
    world_size = 4
    distributed_run(worker_fn_test_shutdown_idle, world_size)


@worker_fn_wrapper
def worker_fn_test_idle_to_busy():
    """Verify a reader transitions from idle waiting to busy spinning.

    Counts calls to SpinCondition.wait: exactly one per dequeue while
    idle, then many per dequeue once a successful read re-enters the
    busy spin.
    """
    my_rank = dist.get_rank()
    writer_rank = 2
    mq = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    first_payload = "hello world"
    # NOTE(review): every rank draws from the same inherited RNG state
    # (fork-based spawn), so readers and writer presumably generate the
    # same array — confirm if the start method changes.
    second_payload = np.random.randint(1, 100, 100)
    with mock.patch.object(
        mq._spin_condition, "wait", wraps=mq._spin_condition.wait
    ) as spied_wait:
        if not mq._is_writer:
            # Put into idle mode
            mq._spin_condition.last_read = 0

            # no messages, so expect a TimeoutError
            with pytest.raises(TimeoutError):
                mq.dequeue(timeout=0.01)
            # wait should only be called once while idle
            assert spied_wait.call_count == 1

            # sync with the writer and wait for the first payload
            dist.barrier()
            got = mq.dequeue(timeout=5)
            assert got == first_payload
            # second call to wait, with a message read, this puts in a busy spin
            assert spied_wait.call_count == 2

            # sync with the writer and wait for the second payload
            dist.barrier()
            got = mq.dequeue(timeout=1)
            assert np.array_equal(got, second_payload)
            # in busy mode, we expect wait to have been called multiple times
            assert spied_wait.call_count > 3
        else:
            # writer sends two messages, each gated on a barrier; the short
            # sleep delays the send so readers enter their read loop first
            dist.barrier()
            time.sleep(0.1)
            mq.enqueue(first_payload)

            dist.barrier()
            time.sleep(0.1)
            mq.enqueue(second_payload)

    mq.shutdown()
    assert mq.shutting_down
    print(f"torch distributed passed the test! Rank {my_rank}")


def test_message_queue_idle_wake():
    """Run the idle-to-busy transition worker on four ranks."""
    world_size = 4
    distributed_run(worker_fn_test_idle_to_busy, world_size)


@worker_fn_wrapper
def worker_fn_test_busy_to_idle():
    """Verify a reader falls back from busy spinning to idle waiting.

    Counts calls to SpinCondition.wait: many per dequeue while busy,
    then exactly one per dequeue after the busy-loop window is zeroed.
    """
    my_rank = dist.get_rank()
    writer_rank = 2
    mq = MessageQueue.create_from_process_group(
        dist.group.WORLD, 40 * 1024, 2, writer_rank
    )

    first_payload = 12345
    second_payload = list(range(3))
    with mock.patch.object(
        mq._spin_condition, "wait", wraps=mq._spin_condition.wait
    ) as spied_wait:
        if not mq._is_writer:
            # Put into busy mode
            mq._spin_condition.busy_loop_s = 9999

            # sync with the writer and wait for the first payload
            dist.barrier()
            got = mq.dequeue(timeout=1)
            assert got == first_payload
            # in busy mode, we expect wait to have been called many times
            assert spied_wait.call_count > 1

            # simulate busy loop ending
            mq._spin_condition.busy_loop_s = 0
            # ensure we enter idle mode, then record call count
            with pytest.raises(TimeoutError):
                mq.dequeue(timeout=0.01)
            waits_after_idle = spied_wait.call_count

            # sync with the writer and wait for the second payload
            dist.barrier()
            got = mq.dequeue(timeout=1)
            assert got == second_payload

            # call to wait after idle should only happen once
            assert spied_wait.call_count == waits_after_idle + 1
        else:
            # writer sends two messages, each gated on a barrier; the short
            # sleep delays the send so readers enter their read loop first
            dist.barrier()
            time.sleep(0.1)
            mq.enqueue(first_payload)

            dist.barrier()
            time.sleep(0.1)
            mq.enqueue(second_payload)

    mq.shutdown()
    assert mq.shutting_down
    print(f"torch distributed passed the test! Rank {my_rank}")


def test_message_queue_busy_to_idle():
    """Run the busy-to-idle transition worker on four ranks."""
    world_size = 4
    distributed_run(worker_fn_test_busy_to_idle, world_size)


def test_warning_logs(caplog_vllm):
    """
    Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals
    when indefinite=False, and are not emitted when indefinite=True.
    """
    expected_warning = "No available shared memory broadcast block found in 0 seconds"

    # Patch the warning log interval to every 1 ms during reads
    with mock.patch(
        "vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL",
        new=0.001,  # 1 ms
    ):
        writer = MessageQueue(
            n_reader=1,
            n_local_reader=1,
            max_chunk_bytes=1024 * 1024,  # 1MB chunks
            max_chunks=10,
        )
        reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0)
        writer.wait_until_ready()
        reader.wait_until_ready()

        try:
            # We should have at least one warning log here
            # "0 seconds" expected due to rounding of 1ms test interval
            with pytest.raises(TimeoutError):
                reader.dequeue(timeout=0.01, indefinite=False)
            assert any(
                expected_warning in record.message for record in caplog_vllm.records
            )
            caplog_vllm.clear()

            # We should have no warnings this time
            with pytest.raises(TimeoutError):
                reader.dequeue(timeout=0.01, indefinite=True)
            assert not any(
                expected_warning in record.message for record in caplog_vllm.records
            )
        finally:
            # Always clean up, even when an assertion above fails, so the
            # shared-memory queue does not leak into later tests.
            writer.shutdown()
            reader.shutdown()
Loading