Merged
Changes from all commits
30 commits
876bab1
support non-ray DI
luccafong Aug 19, 2025
c1fec8a
add tests
luccafong Sep 8, 2025
a748250
fix lint
luccafong Oct 10, 2025
3db52e0
merge mp and mp distributed
luccafong Oct 17, 2025
058d9d5
clean up distributed executor and address comments
luccafong Oct 17, 2025
2d56d06
add executor test
luccafong Oct 18, 2025
1913bd3
Merge remote-tracking branch 'origin/main' into lucia/non-ray-di
luccafong Oct 18, 2025
8984b7d
executor only
luccafong Oct 24, 2025
bfb3dd7
revert change to core.py
luccafong Oct 27, 2025
aad47eb
address minor comments
luccafong Oct 27, 2025
e868c24
compatibility with dp
luccafong Oct 30, 2025
b20664c
Merge branch 'main' into lucia/non-ray-di
luccafong Oct 30, 2025
55439d6
fix lint
luccafong Oct 30, 2025
2d401b0
Merge remote-tracking branch 'origin/main' into lucia/non-ray-di
luccafong Nov 7, 2025
bd8614b
add more comments
luccafong Nov 7, 2025
3733404
simplify process
njhill Nov 7, 2025
0291991
bit more cleanup
njhill Nov 8, 2025
7ceea17
move method
njhill Nov 8, 2025
8ca2301
simplify lb mode selection and fix serving and lint issues
luccafong Nov 10, 2025
05668ce
fix test
luccafong Nov 10, 2025
2938621
Merge remote-tracking branch 'origin/main' into lucia/non-ray-di
luccafong Nov 10, 2025
d5463a7
add assertion
luccafong Nov 10, 2025
f23904a
add back version number in launch msg
luccafong Nov 12, 2025
acc44db
add more comments explaining blocking
luccafong Nov 12, 2025
93e606a
arg change and model executor cleanup
luccafong Nov 14, 2025
3074132
clean up gpu model runner code that is not needed
luccafong Nov 14, 2025
046cb67
fix lint
luccafong Nov 14, 2025
6de8648
remove print
luccafong Nov 14, 2025
3178648
address comment
luccafong Nov 14, 2025
7595a25
fix test for the positional arg used
luccafong Nov 16, 2025
437 changes: 437 additions & 0 deletions tests/distributed/test_multiproc_executor.py

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions vllm/config/parallel.py
@@ -210,6 +210,18 @@ class ParallelConfig:
class is dynamically inherited by the worker class. This is used to inject
new attributes and methods to the worker class for use in collective_rpc
calls."""
master_addr: str = "127.0.0.1"
"""distributed master address for multi-node distributed
inference when distributed_executor_backend is mp."""
master_port: int = 29501
"""distributed master port for multi-node distributed
inference when distributed_executor_backend is mp."""
node_rank: int = 0
"""distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp."""
nnodes: int = 1
Collaborator

I don't think it's a good idea to add nnodes as an attribute to the very general ParallelConfig class, where it is only meaningful when distributed_executor_backend is mp. Many AI labs rely heavily on Ray, and when the distributed executor backend is ray, nnodes will always be wrong here as it'll stay 1 even if tensor_parallel is set to something like 16 on 8xH200.

It's not intuitive for self.vllm_config.parallel_config.nnodes to return 1 here for the ray backend, and it's not enough to just state in the comment that it's only for mp. IMO we need to make sure that nnodes always reflects the correct number of nodes no matter the parallel backend. Similarly, node_rank also needs to make sense for ray. If it's impossible to put good values in for ray, we should at least rename it to something like nnodes_for_mp_backend.

Collaborator Author
@luccafong luccafong Nov 24, 2025

nnodes_for_mp_backend might be pretty hard to use; would setting them to None be a better way to resolve the conflict and confusion? We can then raise an error if a user sets them for the ray backend. @patrickvonplaten

"""num of nodes for multi-node distributed
inference when distributed_executor_backend is mp."""

world_size: int = Field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""
@@ -387,6 +399,23 @@ def use_sequence_parallel_moe(self) -> bool:
and self.data_parallel_size > 1
)

@property
def node_rank_within_dp(self) -> int:
return self.node_rank % self.nnodes_within_dp

@property
def nnodes_within_dp(self) -> int:
if self.nnodes == 1:
return 1
data_parallel_node_size = (
self.data_parallel_size // self.data_parallel_size_local
)
return self.nnodes // data_parallel_node_size

@property
def local_world_size(self) -> int:
return self.world_size // self.nnodes_within_dp

@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
@@ -528,6 +557,8 @@ def __post_init__(self) -> None:
ray_found = ray_utils.ray_is_available()
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
elif current_platform.is_cuda() and self.nnodes > 1:
backend = "mp"
elif (
current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size
@@ -565,6 +596,10 @@ def __post_init__(self) -> None:
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
if self.distributed_executor_backend != "mp" and self.nnodes > 1:
raise ValueError(
"nnodes > 1 can only be set when distributed exectuor backend is mp."
)

@property
def use_ray(self) -> bool:
@@ -607,6 +642,11 @@ def _verify_args(self) -> Self:
"Disabled the custom all-reduce kernel because it is not "
"supported on current platform."
)
if self.nnodes > 1:
self.disable_custom_all_reduce = True
logger.debug(
"Disabled the custom all-reduce since we are running on multi-node."
)
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError(
"Unable to use nsight profiling unless workers run with Ray."
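Not part of the diff: a minimal standalone sketch of the topology arithmetic that the new nnodes_within_dp, node_rank_within_dp, and local_world_size properties implement, assuming TP x PP = 16 spread across 2 nodes of 8 GPUs with a single DP replica. The class name and the numbers are illustrative, not vLLM code.

# Standalone sketch of the ParallelConfig topology math added in this diff.
from dataclasses import dataclass


@dataclass
class TopologySketch:
    world_size: int                # TP x PP per DP replica
    nnodes: int                    # total nodes (mp backend only)
    node_rank: int                 # this node's rank among all nodes
    data_parallel_size: int
    data_parallel_size_local: int

    @property
    def nnodes_within_dp(self) -> int:
        if self.nnodes == 1:
            return 1
        # Nodes are partitioned into (data_parallel_size // data_parallel_size_local)
        # groups; dividing nnodes by that gives the nodes spanned by one DP group.
        dp_node_size = self.data_parallel_size // self.data_parallel_size_local
        return self.nnodes // dp_node_size

    @property
    def node_rank_within_dp(self) -> int:
        return self.node_rank % self.nnodes_within_dp

    @property
    def local_world_size(self) -> int:
        return self.world_size // self.nnodes_within_dp


# Example: a single DP replica spanning 2 nodes of 8 GPUs each.
cfg = TopologySketch(
    world_size=16,
    nnodes=2,
    node_rank=1,
    data_parallel_size=1,
    data_parallel_size_local=1,
)
assert cfg.nnodes_within_dp == 2    # the replica spans both nodes
assert cfg.node_rank_within_dp == 1
assert cfg.local_world_size == 8    # 8 workers launched on this node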
110 changes: 94 additions & 16 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -8,7 +8,7 @@
from multiprocessing import shared_memory
from pickle import PickleBuffer
from threading import Event
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch

import torch
@@ -602,13 +602,87 @@ def broadcast_object(self, obj=None):
return obj
return self.dequeue()

@staticmethod
def create_from_process_group_single_reader(
pg: ProcessGroup,
max_chunk_bytes,
max_chunks,
reader_rank: int = 0,
blocking: bool = False,
) -> tuple["MessageQueue", list[Handle]]:
"""
Creates a MessageQueue for a process group with a single reader.

This method is designed for scenarios where only one process (the reader)
will consume messages, and all other processes are writers. It sets up
the shared memory buffer and communication handles accordingly, and
gathers the handles from all processes to the reader.

Args:
pg (ProcessGroup): The torch distributed process group.
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
max_chunks (int): Maximum number of chunks in the buffer.
reader_rank (int, optional): The global rank that will act as the reader.
Defaults to 0.
blocking (bool, optional): If True, blocks until all processes are ready.
Defaults to False.

Returns:
tuple[MessageQueue, list[Handle]]:
The MessageQueue instance for the calling process,
and a list of handles (only non-empty for the reader process).
"""
local_size = torch.cuda.device_count()
rank = dist.get_rank()
same_node = rank // local_size == reader_rank // local_size
buffer_io = MessageQueue(
n_reader=1,
n_local_reader=1 if same_node else 0,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
dist.gather_object(handle, handles, dst=reader_rank, group=pg)
if blocking:
buffer_io.wait_until_ready()
return buffer_io, cast(list[Handle], handles or [])

@staticmethod
def create_from_process_group(
pg: ProcessGroup | StatelessProcessGroup,
max_chunk_bytes,
max_chunks,
writer_rank=0,
writer_rank: int = 0,
external_writer_handle=None,
blocking: bool = True,
) -> "MessageQueue":
"""
Creates a MessageQueue for a distributed process group with one writer and
multiple readers.

This method is designed for scenarios where one process (the writer) sends
messages, and all other processes (the readers) receive messages. It sets up
the shared memory buffer and socket communication handles accordingly, and
broadcasts the handle from the writer to all readers.

Args:
pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
group.
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
max_chunks (int): Maximum number of chunks in the buffer.
writer_rank (int, optional): The global rank that will act as the writer.
Defaults to 0.
external_writer_handle (Handle, optional): Handle from an external
MessageQueue. If provided, the writer initializes its queue from this
handle instead of creating a new one. Defaults to None.
blocking (bool, optional): If True, blocks until all processes are ready.
Defaults to True.

Returns:
MessageQueue: The MessageQueue instance for the calling process.

"""
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
@@ -617,23 +691,26 @@ def create_from_process_group(
group_rank = pg.rank
group_world_size = pg.world_size
global_ranks = list(range(pg.world_size))

from vllm.distributed.parallel_state import in_the_same_node_as

status = in_the_same_node_as(pg, source_rank=writer_rank)
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
n_local_reader = len(same_node_ranks) - 1
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
buffer_io: MessageQueue
if group_rank == writer_rank:
buffer_io = MessageQueue(
n_reader=n_reader,
n_local_reader=n_local_reader,
local_reader_ranks=local_reader_ranks,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
if external_writer_handle is not None:
buffer_io = MessageQueue.create_from_handle(
external_writer_handle, group_rank
)
else:
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
n_local_reader = len(same_node_ranks) - 1
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
buffer_io = MessageQueue(
n_reader=n_reader,
n_local_reader=n_local_reader,
local_reader_ranks=local_reader_ranks,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list(
Expand All @@ -651,5 +728,6 @@ def create_from_process_group(
else:
handle = pg.broadcast_obj(None, writer_rank)
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
buffer_io.wait_until_ready()
if blocking:
buffer_io.wait_until_ready()
return buffer_io
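Not part of the diff: a rough usage sketch contrasting the two factory methods above, with one writer fanning out to all readers versus all ranks writing back to a single reader. It assumes torch.distributed is already initialized on every rank and that a gloo CPU group covering all ranks is available; the 1 << 22 / 6 sizing mirrors create_mq_broadcaster later in this PR, and the rank roles and variable names are illustrative.

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

# Assumed setup: dist.init_process_group(...) has already run on every rank.
cpu_group = dist.new_group(backend="gloo")  # CPU group covering all ranks

# Pattern 1: one writer (rank 0) broadcasting to every other rank,
# e.g. an executor fanning RPCs out to workers.
rpc_mq = MessageQueue.create_from_process_group(
    cpu_group,
    max_chunk_bytes=1 << 22,
    max_chunks=6,
    writer_rank=0,
    blocking=True,  # wait until all readers have attached
)

# Pattern 2: every rank writing to a single reader (rank 0),
# e.g. workers returning results; only the reader receives the handle list.
result_mq, handles = MessageQueue.create_from_process_group_single_reader(
    cpu_group,
    max_chunk_bytes=1 << 22,
    max_chunks=6,
    reader_rank=0,
    blocking=False,  # readiness can be awaited later via wait_until_ready()
)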
77 changes: 73 additions & 4 deletions vllm/distributed/parallel_state.py
@@ -385,6 +385,33 @@ def __init__(
torch.ops._C, "init_shm_manager"
)

def create_mq_broadcaster(
self, writer_rank=0, external_writer_handle=None, blocking=True
):
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

return MessageQueue.create_from_process_group(
self.cpu_group,
1 << 22,
6,
writer_rank=writer_rank,
external_writer_handle=external_writer_handle,
blocking=blocking,
)

def create_single_reader_mq_broadcasters(
self, reader_rank_in_group=0, blocking=False
):
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

return MessageQueue.create_from_process_group_single_reader(
self.cpu_group,
1 << 22,
6,
reader_rank=self.ranks[reader_rank_in_group],
blocking=blocking,
)

@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
@@ -997,6 +1024,7 @@ def combine(


_WORLD: GroupCoordinator | None = None
_INNER_DP_WORLD: GroupCoordinator | None = None
_NODE_COUNT: int | None = None


Expand All @@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator:
return _WORLD


def get_inner_dp_world_group() -> GroupCoordinator:
assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized"
return _INNER_DP_WORLD


def init_world_group(
ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
@@ -1023,12 +1056,13 @@ def init_model_parallel_group(
backend: str,
use_message_queue_broadcaster: bool = False,
group_name: str | None = None,
use_device_communicator: bool = True,
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_device_communicator=True,
use_device_communicator=use_device_communicator,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
@@ -1143,7 +1177,14 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config

config = get_current_vllm_config()
if (
if config is not None and config.parallel_config.nnodes > 1:
parallel_config = config.parallel_config
ip = parallel_config.master_addr
rank = parallel_config.data_parallel_rank * world_size + rank
world_size = parallel_config.world_size_across_dp
port = parallel_config.master_port
distributed_init_method = get_distributed_init_method(ip, port)
elif (
config is not None
and config.parallel_config.data_parallel_size > 1
and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1164,6 +1205,14 @@ def init_distributed_environment(
distributed_init_method,
)
if not torch.distributed.is_initialized():
logger.info(
"world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
world_size,
rank,
local_rank,
distributed_init_method,
backend,
)
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment"
@@ -1192,16 +1241,36 @@ def init_distributed_environment(
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
_NODE_COUNT = _node_count(_WORLD.cpu_group)
if config.parallel_config.nnodes > 1:
_NODE_COUNT = config.parallel_config.nnodes
else:
_NODE_COUNT = _node_count(_WORLD.cpu_group)
logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size"
)
if config.parallel_config.nnodes_within_dp > 1:
if parallel_config.data_parallel_size > 1:
world_size_inner_dp = parallel_config.world_size
group_ranks = [
[dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
for dp_rank in range(parallel_config.data_parallel_size)
]
_INNER_DP_WORLD = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="inner_dp_world",
use_device_communicator=False,
)
else:
_INNER_DP_WORLD = _WORLD


def initialize_model_parallel(
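Not part of the diff: a small worked example of the rank remapping done in init_distributed_environment for the nnodes > 1 mp path and of the inner-DP group layout built above. The relation world_size_across_dp = data_parallel_size * world_size is assumed from the group construction; the numbers are illustrative.

# Worked example of the multi-node mp rank math (illustrative numbers).
data_parallel_size = 2
world_size = 8  # TP x PP per DP replica
world_size_across_dp = data_parallel_size * world_size  # assumed relation -> 16

# A worker with local rank 3 inside DP replica 1 is remapped to this global rank:
data_parallel_rank = 1
rank_within_replica = 3
global_rank = data_parallel_rank * world_size + rank_within_replica  # 11

# Inner-DP world groups, one per DP replica, mirroring the group_ranks
# construction in init_distributed_environment above:
group_ranks = [
    [dp_rank * world_size + i for i in range(world_size)]
    for dp_rank in range(data_parallel_size)
]
assert group_ranks == [list(range(0, 8)), list(range(8, 16))]
assert global_rank in group_ranks[data_parallel_rank]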