Merged

Commits (30)
876bab1
support non-ray DI
luccafong Aug 19, 2025
c1fec8a
add tests
luccafong Sep 8, 2025
a748250
fix lint
luccafong Oct 10, 2025
3db52e0
merge mp and mp distributed
luccafong Oct 17, 2025
058d9d5
clean up distributed executor and address comments
luccafong Oct 17, 2025
2d56d06
add executor test
luccafong Oct 18, 2025
1913bd3
Merge remote-tracking branch 'origin/main' into lucia/non-ray-di
luccafong Oct 18, 2025
8984b7d
executor only
luccafong Oct 24, 2025
bfb3dd7
revert change to core.py
luccafong Oct 27, 2025
aad47eb
address minor comments
luccafong Oct 27, 2025
e868c24
compatibility with dp
luccafong Oct 30, 2025
b20664c
Merge branch 'main' into lucia/non-ray-di
luccafong Oct 30, 2025
55439d6
fix lint
luccafong Oct 30, 2025
2d401b0
Merge remote-tracking branch 'origin/main' into lucia/non-ray-di
luccafong Nov 7, 2025
bd8614b
add more comments
luccafong Nov 7, 2025
3733404
simplify process
njhill Nov 7, 2025
0291991
bit more cleanup
njhill Nov 8, 2025
7ceea17
move method
njhill Nov 8, 2025
8ca2301
simplify lb mode selection and fix serving and lint issues
luccafong Nov 10, 2025
05668ce
fix test
luccafong Nov 10, 2025
2938621
Merge remote-tracking branch 'origin/main' into lucia/non-ray-di
luccafong Nov 10, 2025
d5463a7
add assertion
luccafong Nov 10, 2025
f23904a
add back version number in launch msg
luccafong Nov 12, 2025
acc44db
add more comments explaining blocking
luccafong Nov 12, 2025
93e606a
arg change and model executor cleanup
luccafong Nov 14, 2025
3074132
clean up gpu model runner code that is not needed
luccafong Nov 14, 2025
046cb67
fix lint
luccafong Nov 14, 2025
6de8648
remove print
luccafong Nov 14, 2025
3178648
address comment
luccafong Nov 14, 2025
7595a25
fix test for the positional arg used
luccafong Nov 16, 2025
437 changes: 437 additions & 0 deletions tests/distributed/test_multiproc_executor.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tools/pre_commit/check_pickle_imports.py
@@ -37,7 +37,6 @@
"benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
"benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
# cloudpickle
"vllm/executor/mp_distributed_executor.py",
Collaborator: why remove this one?

Collaborator Author: this file is deleted already

"vllm/executor/ray_distributed_executor.py",
"vllm/entrypoints/llm.py",
"tests/utils.py",
17 changes: 17 additions & 0 deletions vllm/config/parallel.py
@@ -67,6 +67,15 @@ class EPLBConfig:
class ParallelConfig:
"""Configuration for the distributed execution."""

distributed_master_ip: str = "127.0.0.1"
"""distributed master ip for multi-node distributed
inference when distributed_executor_backend is mp."""
distributed_master_port: int = 0
"""distributed master port """
distributed_node_rank: int = 0
Collaborator: better to give an example to elaborate the meaning of distributed_node_rank and distributed_node_size.

"""distributed node rank """
distributed_node_size: int = 1
"""distributed node size """
pipeline_parallel_size: int = 1
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
@@ -516,9 +525,12 @@ def __post_init__(self) -> None:
ray_found = ray_utils.ray_is_available()
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni"
elif current_platform.is_cuda() and self.distributed_node_size > 1:
backend = "mp"
elif (
current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size
and self.distributed_node_size == 1
):
if not ray_found:
raise ValueError(
@@ -593,6 +605,11 @@ def _verify_args(self) -> Self:
"Disabled the custom all-reduce kernel because it is not "
"supported on current platform."
)
if self.distributed_node_size > 1:
self.disable_custom_all_reduce = True
logger.debug(
"Disabled the custom all-reduce since we are running on multi-node."
)
if self.ray_workers_use_nsight and not self.use_ray:
raise ValueError(
"Unable to use nsight profiling unless workers run with Ray."
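As the review above asks, here is a quick illustration of how the new fields are meant to fit together on a hypothetical two-node deployment. This is a sketch only: it assumes the torchrun-style node rank/size convention suggested by the follower-node handling in core.py further down, and every address and port below is a placeholder.

# Hypothetical two-node layout for the new ParallelConfig fields.
tensor_parallel_size = 16           # total workers across the deployment
distributed_node_size = 2           # number of participating nodes
gpus_per_node = tensor_parallel_size // distributed_node_size  # 8 per node

# Each node launches the same engine command; only the rank differs:
#   node A: distributed_node_rank = 0  -> head node, runs the scheduler
#   node B: distributed_node_rank = 1  -> follower node, scheduler-less
# Both nodes point at the head node:
distributed_master_ip = "10.0.0.1"  # placeholder address of node A
distributed_master_port = 29500     # placeholder port on node A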
58 changes: 48 additions & 10 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -594,12 +594,42 @@ def broadcast_object(self, obj=None):
return obj
return self.dequeue()

@staticmethod
def create_from_process_group_single_reader(
pg: ProcessGroup,
max_chunk_bytes,
max_chunks,
reader_rank: int = 0,
blocking: bool = False,
):
"""Create a message queue from ranks."""
# We assume same size acrsso groups
local_size = torch.cuda.device_count()
group_rank = dist.get_rank(pg)
same_node = group_rank // local_size == reader_rank // local_size
buffer_io = MessageQueue(
n_reader=1,
n_local_reader=1 if same_node else 0,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
handles = (
[None] * dist.get_world_size(pg) if group_rank == reader_rank else None
)
dist.gather_object(handle, handles, dst=reader_rank, group=pg)
if blocking:
buffer_io.wait_until_ready()
return (buffer_io, handles if handles is not None else [])

@staticmethod
def create_from_process_group(
pg: ProcessGroup | StatelessProcessGroup,
max_chunk_bytes,
max_chunks,
writer_rank=0,
writer_rank: int = 0,
extra_writer_handler=None,
blocking: bool = True,
) -> "MessageQueue":
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
@@ -609,23 +639,30 @@ def create_from_process_group(
group_rank = pg.rank
group_world_size = pg.world_size
global_ranks = list(range(pg.world_size))

from vllm.distributed.parallel_state import in_the_same_node_as

status = in_the_same_node_as(pg, source_rank=writer_rank)
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
n_local_reader = len(same_node_ranks) - 1
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
if extra_writer_handler is not None:
n_reader = group_world_size
n_local_reader = len(same_node_ranks)
Member: why a separate if for these here? It could go in the else below.

buffer_io: MessageQueue
if group_rank == writer_rank:
buffer_io = MessageQueue(
n_reader=n_reader,
n_local_reader=n_local_reader,
local_reader_ranks=local_reader_ranks,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
if extra_writer_handler is not None:
buffer_io = MessageQueue.create_from_handle(
extra_writer_handler, group_rank
)
else:
buffer_io = MessageQueue(
n_reader=n_reader,
n_local_reader=n_local_reader,
local_reader_ranks=local_reader_ranks,
max_chunk_bytes=max_chunk_bytes,
max_chunks=max_chunks,
)
handle = buffer_io.export_handle()
if isinstance(pg, ProcessGroup):
dist.broadcast_object_list(
@@ -643,5 +680,6 @@
else:
handle = pg.broadcast_obj(None, writer_rank)
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
buffer_io.wait_until_ready()
if blocking:
buffer_io.wait_until_ready()
return buffer_io
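A rough usage sketch of the new single-reader constructor. It assumes a gloo process group spanning all participating processes and reuses the 4 MiB / 6-chunk sizing of the GroupCoordinator helpers added below; the reader-side attach in the final block mirrors create_from_process_group and is an assumption, not something this diff shows.

import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

# Assumed setup: a gloo group over all engine/worker processes.
dist.init_process_group(backend="gloo", init_method="env://")
cpu_group = dist.group.WORLD

# Every rank owns a queue with exactly one reader (reader_rank); the reader
# rank additionally gets back the gathered handles of all writers.
queue, handles = MessageQueue.create_from_process_group_single_reader(
    cpu_group,
    max_chunk_bytes=1 << 22,  # 4 MiB, matching create_mq_broadcaster below
    max_chunks=6,
    reader_rank=0,
    blocking=False,           # don't wait for the reader to attach yet
)

if dist.get_rank(cpu_group) == 0:
    # Assumption: the reader attaches to each writer's handle the same way
    # readers do in create_from_process_group, via create_from_handle.
    reader_queues = [MessageQueue.create_from_handle(h, 0) for h in handles[1:]]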
44 changes: 42 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -385,6 +385,27 @@
torch.ops._C, "init_shm_manager"
)

def create_mq_broadcaster(
self, writer_rank=0, extra_writer_handler=None, blocking=True
):
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

return MessageQueue.create_from_process_group(
self.cpu_group,
1 << 22,
6,
writer_rank=writer_rank,
extra_writer_handler=extra_writer_handler,
blocking=blocking,
)

def create_single_reader_mq_broadcasters(self, reader_rank=0, blocking=False):
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

return MessageQueue.create_from_process_group_single_reader(
self.cpu_group, 1 << 22, 6, reader_rank=reader_rank, blocking=blocking
)

@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
@@ -1155,15 +1176,30 @@ def init_distributed_environment(
# adjust the world size to take into account data parallelism
world_size = parallel_config.world_size_across_dp
ip = parallel_config.data_parallel_master_ip
port = parallel_config.get_next_dp_init_port()
if config.parallel_config.distributed_master_port > 0:
port = config.parallel_config.distributed_master_port
else:
port = parallel_config.get_next_dp_init_port()
distributed_init_method = get_distributed_init_method(ip, port)
logger.info(
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
world_size,
rank,
distributed_init_method,
)
elif config is not None and config.parallel_config.distributed_node_size > 1:
ip = config.parallel_config.distributed_master_ip
port = config.parallel_config.distributed_master_port
distributed_init_method = get_distributed_init_method(ip, port)
if not torch.distributed.is_initialized():
logger.info(
"world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
world_size,
rank,
local_rank,
distributed_init_method,
backend,
)
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment"
@@ -1196,7 +1232,11 @@
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
_NODE_COUNT = _node_count(_WORLD.cpu_group)
if config.parallel_config.distributed_node_size > 1:
# TODO: fix me (the _node_count call below hits a connection reset across nodes)
_NODE_COUNT = config.parallel_config.distributed_node_size
else:
_NODE_COUNT = _node_count(_WORLD.cpu_group)
logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
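For context, a sketch of how the two new helpers might be called. It assumes they are being added to the existing GroupCoordinator class (the surrounding hunk suggests so), that the distributed environment is already initialized, and that get_world_group() is a reasonable group to use for illustration; the real executor wiring lives elsewhere in this PR.

from vllm.distributed.parallel_state import get_world_group

group = get_world_group()  # existing accessor for the world GroupCoordinator

# Many-reader queue: writer_rank writes, every other rank in the group reads.
# blocking=True returns only once the queue is ready on all participants.
work_mq = group.create_mq_broadcaster(writer_rank=0, blocking=True)

# Single-reader queues: every rank writes its own queue, and rank 0 (the
# reader) receives the gathered handles so it can attach to each of them.
out_mq, handles = group.create_single_reader_mq_broadcasters(reader_rank=0)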
20 changes: 20 additions & 0 deletions vllm/engine/arg_utils.py
@@ -370,6 +370,10 @@ class EngineArgs:
) = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
distributed_master_ip: str = ParallelConfig.distributed_master_ip
distributed_master_port: int = ParallelConfig.distributed_master_port
distributed_node_size: int = ParallelConfig.distributed_node_size
distributed_node_rank: int = ParallelConfig.distributed_node_rank
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
@@ -720,6 +724,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"-pp",
**parallel_kwargs["pipeline_parallel_size"],
)
parallel_group.add_argument(
"--distributed-master-ip", **parallel_kwargs["distributed_master_ip"]
Collaborator: Add some comments as the help field?

Collaborator Author: it is included in the ParallelConfig docstrings.

)
parallel_group.add_argument(
"--distributed-master-port", **parallel_kwargs["distributed_master_port"]
)
parallel_group.add_argument(
"--distributed-node-size", **parallel_kwargs["distributed_node_size"]
)
parallel_group.add_argument(
"--distributed-node-rank", **parallel_kwargs["distributed_node_rank"]
)
parallel_group.add_argument(
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
)
@@ -1471,6 +1487,10 @@ def create_engine_config(
data_parallel_rank=self.data_parallel_rank or 0,
data_parallel_external_lb=data_parallel_external_lb,
data_parallel_size_local=data_parallel_size_local,
distributed_master_ip=self.distributed_master_ip,
distributed_master_port=self.distributed_master_port,
distributed_node_size=self.distributed_node_size,
distributed_node_rank=self.distributed_node_rank,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,
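A sketch of exercising the new flags through the CLI parser. The FlexibleArgumentParser import path and the model name are assumptions; EngineArgs.add_cli_args and EngineArgs.from_cli_args are the existing entry points.

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser  # assumed import location

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    "--tensor-parallel-size", "16",
    "--distributed-executor-backend", "mp",
    "--distributed-master-ip", "10.0.0.1",  # placeholder head-node IP
    "--distributed-master-port", "29500",
    "--distributed-node-size", "2",
    "--distributed-node-rank", "0",         # pass 1 on the second node
])
engine_args = EngineArgs.from_cli_args(args)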
52 changes: 37 additions & 15 deletions vllm/v1/engine/core.py
@@ -108,7 +108,16 @@ def __init__(
self.model_executor.register_failure_callback(executor_fail_callback)

self.available_gpu_memory_for_kv_cache = -1
# No scheduler is needed for distributed inference on follower nodes.
self.batch_queue_size = 0
self.batch_queue: (
deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
) = None

self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
if self.vllm_config.parallel_config.distributed_node_rank > 0:
self._scheduler = None
return
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config
@@ -150,7 +159,7 @@ def __init__(
* vllm_config.parallel_config.decode_context_parallel_size
)

self.scheduler: SchedulerInterface = Scheduler(
self._scheduler = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
@@ -174,14 +183,10 @@ def __init__(
# schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: (
deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
) = None
if self.batch_queue_size > 1:
logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
self.batch_queue = deque(maxlen=self.batch_queue_size)

self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
if (
self.vllm_config.cache_config.enable_prefix_caching
or self.scheduler.get_kv_connector() is not None
@@ -199,6 +204,16 @@ def __init__(
self.step if self.batch_queue is None else self.step_with_batch_queue
)

@property
def scheduler(self) -> SchedulerInterface:
if not isinstance(self._scheduler, SchedulerInterface):
raise RuntimeError("Scheduler is not initialized")
return self._scheduler

@property
def scheduless_mode(self) -> bool:
return self._scheduler is None

def _initialize_kv_caches(
self, vllm_config: VllmConfig
) -> tuple[int, int, KVCacheConfig]:
@@ -393,19 +408,20 @@ def step_with_batch_queue(
return engine_core_outputs, model_executed

def shutdown(self):
self.structured_output_manager.clear_backend()
if self.structured_output_manager:
self.structured_output_manager.clear_backend()
if self.model_executor:
self.model_executor.shutdown()
if self.scheduler:
if not self.scheduless_mode:
self.scheduler.shutdown()

def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)

def reset_mm_cache(self):
# NOTE: Since this is mainly for debugging, we don't attempt to
# re-sync the internal caches (P0 sender, P1 receiver)
if self.scheduler.has_unfinished_requests():
# re-sync the internal caches (P0 sender, P1 receiver)
if not self.scheduless_mode and self.scheduler.has_unfinished_requests():
logger.warning(
"Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches."
@@ -805,12 +821,18 @@ def _init_data_parallel(self, vllm_config: VllmConfig):
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""

# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
if not self.scheduless_mode:
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
else:
# Loop until process is sent a SIGINT or SIGTERM
while True:
# No real scheduler for follower nodes
time.sleep(1)

def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
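To recap the engine-core split this file introduces, a condensed sketch with hypothetical values; the diff above is authoritative.

# Condensed view of the head/follower split (hypothetical values only).
distributed_node_rank = 1                  # from --distributed-node-rank
scheduless_mode = distributed_node_rank > 0

if scheduless_mode:
    # Follower engine cores skip KV-cache sizing, the Scheduler, and the
    # batch queue, and their busy loop idles (see run_busy_loop above);
    # work on these nodes is assumed to be driven from the head node.
    pass
else:
    # The head node (rank 0) sets up KV caches, the Scheduler, and the
    # batch queue, and runs the normal input-queue / step loop.
    pass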
9 changes: 8 additions & 1 deletion vllm/v1/engine/utils.py
@@ -996,13 +996,20 @@ def wait_for_engine_startup(
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED

elif (
status == "READY"
and engine.state == CoreEngineState.CONNECTED
and parallel_config.distributed_node_rank > 0
):
engine.state = CoreEngineState.READY
start_pending[0 if local else 1] -= 1
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks

# In external DP LB mode, the coordinator address that the
# front-end procs connect to is obtained from rank 0 via
# one of the engine handshakes, and passed to the local