Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3a3a250
Implement RayExecutorV2 & tested on a single-node
jeffreywang-anyscale Mar 10, 2026
df75664
Enable multinode
jeffreywang-anyscale Mar 12, 2026
bbaa21b
Fix pre-commit
jeffreywang-anyscale Mar 16, 2026
2541f2d
Fix RayExecutorV2 monitor thread self-join
jeffreywang-anyscale Mar 16, 2026
c3ad8e5
Remove unnecessary changes
Mar 17, 2026
300d0ae
Extract bundle sorting to a utility
jeffreywang-anyscale Mar 17, 2026
11d32eb
Fix linter
jeffreywang-anyscale Mar 17, 2026
5795f1d
Enable async scheduling
jeffreywang-anyscale Mar 18, 2026
7128074
Address CR feedback
jeffreywang-anyscale Mar 19, 2026
e7a3c1f
Address test feedback
jeffreywang-anyscale Mar 19, 2026
5b4119a
Merge branch 'main' into ray
jeffreywang-anyscale Mar 19, 2026
ec2730d
Iterate over world_size
jeffreywang-anyscale Mar 19, 2026
ca95900
Fix tests and linters
jeffreywang-anyscale Mar 19, 2026
139c02a
Respect VLLM_RAY_BUNDLE_INDICES
jeffreywang-anyscale Mar 22, 2026
7657031
Adjust DP rank for ray executor backend
jeffreywang-anyscale Mar 23, 2026
6c1ea7e
Apply DP local-rank device offset for RayExecutorV2 workers
jeffreywang-anyscale Mar 23, 2026
d040317
Support DP
jeffreywang-anyscale Mar 24, 2026
a76acc9
Fix linter
jeffreywang-anyscale Mar 24, 2026
c9f0a39
Lazily initialize RayWorkerProc
jeffreywang-anyscale Mar 25, 2026
29c7426
Propagate env var; add tests
jeffreywang-anyscale Mar 25, 2026
aae5938
Add nsight profiling and non-GPU device support to RayExecutorV2
jeffreywang-anyscale Mar 25, 2026
25eaf8e
Fix AsyncLLMActor async detection in e2e tests
jeffreywang-anyscale Mar 26, 2026
6717ca2
Fix AsyncLLMActor async detection in e2e tests
jeffreywang-anyscale Mar 26, 2026
cfba15e
Fix test
jeffreywang-anyscale Mar 26, 2026
476501b
Fix test
jeffreywang-anyscale Mar 26, 2026
c7aa661
Fix wrong PYTHONPATH in Ray workers
jeffreywang-anyscale Mar 26, 2026
e0fd321
CR feedback round 1
jeffreywang-anyscale Mar 30, 2026
af21cdd
CR feedback round 2
jeffreywang-anyscale Mar 30, 2026
ad8f6d0
Only apply blacklist & propagate env with setdefault
jeffreywang-anyscale Mar 30, 2026
605a347
Merge branch 'main' into ray
jeffreywang-anyscale Mar 31, 2026
7586204
CR feedback round 3
jeffreywang-anyscale Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .buildkite/test_areas/distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,20 @@ steps:
commands:
- ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 $IMAGE_TAG "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code"

# Two-node job (1 process per node) that exercises the MessageQueue ZMQ/TCP
# fallback path: with isolated /dev/shm per container, readers on the remote
# node must be detected as non-local and served over TCP.
- label: MessageQueue TCP Multi-Node (2 GPUs)
  timeout_in_minutes: 10
  working_dir: "/vllm-workspace/tests"
  num_devices: 1
  num_nodes: 2
  no_plugin: true
  optional: true
  source_file_dependencies:
    - vllm/distributed/device_communicators/shm_broadcast.py
    - vllm/distributed/parallel_state.py
    - tests/distributed/test_mq_tcp_multinode.py
  commands:
    # run-multi-node-test.sh <workdir> <num_nodes> <gpus_per_node> <image> <node0 cmd> <node1 cmd>;
    # both nodes run the same torchrun script and rendezvous at 192.168.10.10.
    - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 1 $IMAGE_TAG "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py" "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py"

- label: Distributed NixlConnector PD accuracy (4 GPUs)
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"
Expand Down Expand Up @@ -294,3 +308,23 @@ steps:
commands:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py

# 4-GPU suite for the Ray V2 executor backend: unit tests, e2e tests, and the
# ray-marked subsets of the pipeline-parallel and basic-correctness suites.
- label: RayExecutorV2 (4 GPUs)
  timeout_in_minutes: 60
  working_dir: "/vllm-workspace/tests"
  num_devices: 4
  source_file_dependencies:
    - vllm/v1/executor/ray_executor_v2.py
    - vllm/v1/executor/abstract.py
    - vllm/v1/executor/multiproc_executor.py
    - tests/distributed/test_ray_v2_executor.py
    - tests/distributed/test_ray_v2_executor_e2e.py
    - tests/distributed/test_pipeline_parallel.py
    - tests/basic_correctness/test_basic_correctness.py
  commands:
    # Route every test in this step through the Ray V2 executor backend.
    - export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1
    # NOTE(review): presumably works around an NCCL cuMem host-allocation
    # issue in this environment — confirm why this is required here.
    - export NCCL_CUMEM_HOST_ENABLE=0
    - pytest -v -s distributed/test_ray_v2_executor.py
    - pytest -v -s distributed/test_ray_v2_executor_e2e.py
    - pytest -v -s distributed/test_pipeline_parallel.py -k "ray"
    - TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray"
29 changes: 29 additions & 0 deletions tests/distributed/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random

import msgspec
Expand Down Expand Up @@ -166,3 +167,31 @@ def close(self):
self.sub.close()
for replay in self.replay_sockets:
replay.close()


@pytest.fixture
def enable_ray_v2_backend():
    """Configure the environment for the Ray V2 executor backend.

    Overrides the relevant environment variables for the duration of the
    test, shuts down any pre-existing Ray session before the test runs,
    and restores both Ray and the environment afterwards.
    """
    import ray

    # Desired settings for the test; previous values are captured so they
    # can be restored (or removed, if they were unset) on teardown.
    overrides = {
        "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": "1",
        "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
    }
    previous = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)

    # Start from a clean slate: no Ray session should leak into the test.
    if ray.is_initialized():
        ray.shutdown()
    try:
        yield
    finally:
        # ...and none should leak out of it.
        if ray.is_initialized():
            ray.shutdown()
        # Restore the environment exactly as it was before the fixture ran.
        for name, value in previous.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
119 changes: 119 additions & 0 deletions tests/distributed/test_mq_tcp_multinode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Multi-node integration test for MessageQueue TCP fallback.

Verifies that when writer and readers span separate nodes (Docker containers
with isolated /dev/shm), `create_from_process_group` correctly detects
cross-node ranks via `in_the_same_node_as()` and falls back to ZMQ TCP
transport — and that data actually arrives.
"""

import numpy as np
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.parallel_state import in_the_same_node_as


def _verify_cross_node_detection(rank: int, world_size: int) -> list[bool]:
    """Check that `in_the_same_node_as` reports at least one remote rank.

    Returns the per-rank locality status (relative to rank 0) so the
    caller can cross-check the MessageQueue transport selection.
    """
    status = in_the_same_node_as(dist.group.WORLD, source_rank=0)
    local_count = sum(status)
    print(
        f"[Rank {rank}] in_the_same_node_as(source=0): {status} "
        f"(local={local_count}/{world_size})"
    )
    # With 2 Docker containers (1 proc each), rank 0 and rank 1
    # should be on different nodes.
    assert local_count < world_size, (
        f"Expected cross-node ranks but all {world_size} ranks appear local."
    )
    return status


def _verify_transport_selection(mq, rank: int, writer_rank: int, status) -> None:
    """Check that each rank got the expected transport (shm vs. TCP)."""
    if rank == writer_rank:
        print(
            f"[Rank {rank}] Writer: n_local_reader={mq.n_local_reader}, "
            f"n_remote_reader={mq.n_remote_reader}"
        )
        assert mq.n_remote_reader > 0, (
            "Writer should have at least 1 remote (TCP) reader in a multi-node setup."
        )
    elif status[rank]:
        # Same node as the writer -> shared-memory path.
        assert mq._is_local_reader, (
            f"Rank {rank} is on the same node as writer but is not a local reader."
        )
        print(f"[Rank {rank}] Reader: local (shared memory)")
    else:
        # Different node -> ZMQ TCP fallback path.
        assert mq._is_remote_reader, (
            f"Rank {rank} is on a different node but is not a remote (TCP) reader."
        )
        print(f"[Rank {rank}] Reader: remote (TCP)")


def _test_simple_object(mq, rank: int, writer_rank: int) -> None:
    """Broadcast a single small Python object and verify receipt."""
    dist.barrier()
    if rank == writer_rank:
        mq.enqueue("hello_from_node0")
    else:
        msg = mq.dequeue(timeout=10)
        assert msg == "hello_from_node0"
    dist.barrier()
    print(f"[Rank {rank}] Simple object test passed")


def _test_numpy_arrays(mq, rank: int, writer_rank: int) -> None:
    """Broadcast many variable-sized numpy arrays and verify contents.

    Every rank seeds numpy identically, so readers can regenerate the
    exact arrays the writer sends and compare element-wise.
    """
    np.random.seed(42)
    arrays = [
        np.random.randint(0, 100, size=np.random.randint(100, 5000)) for _ in range(100)
    ]

    dist.barrier()
    if rank == writer_rank:
        for arr in arrays:
            mq.enqueue(arr)
    else:
        for i, expected in enumerate(arrays):
            received = mq.dequeue(timeout=10)
            assert np.array_equal(expected, received), (
                f"Array mismatch at index {i}: "
                f"expected shape {expected.shape}, got shape {received.shape}"
            )
    dist.barrier()
    print(f"[Rank {rank}] Numpy array test passed")


def _test_large_payload(mq, rank: int, writer_rank: int) -> None:
    """Broadcast a payload larger than max_chunk_bytes (forces chunking)."""
    dist.barrier()
    big_array = np.zeros(200_000, dtype=np.int64)  # ~1.6 MiB > 1 MiB chunk
    if rank == writer_rank:
        mq.enqueue(big_array)
    else:
        received = mq.dequeue(timeout=10)
        assert np.array_equal(big_array, received)
    dist.barrier()
    print(f"[Rank {rank}] Large payload test passed")


def main():
    """Run all MessageQueue TCP multi-node checks (launched via torchrun).

    Initializes a gloo process group, verifies cross-node detection and
    transport selection, then exercises three data-transfer paths:
    small objects, many numpy arrays, and an oversized (chunked) payload.
    """
    dist.init_process_group(backend="gloo")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    assert world_size >= 2, (
        f"Need at least 2 ranks across nodes, got world_size={world_size}"
    )

    # Verify that in_the_same_node_as detects cross-node correctly.
    status = _verify_cross_node_detection(rank, world_size)

    # Create the MessageQueue with a deliberately small chunk size so the
    # large-payload test below exceeds a single chunk.
    writer_rank = 0
    mq = MessageQueue.create_from_process_group(
        dist.group.WORLD,
        max_chunk_bytes=1024 * 1024,  # 1 MiB
        max_chunks=10,
        writer_rank=writer_rank,
    )

    _verify_transport_selection(mq, rank, writer_rank, status)
    _test_simple_object(mq, rank, writer_rank)
    _test_numpy_arrays(mq, rank, writer_rank)
    _test_large_payload(mq, rank, writer_rank)

    # Done -- cleanup
    dist.barrier()
    print(f"[Rank {rank}] All MessageQueue TCP multi-node tests passed!")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
Loading
Loading