diff --git a/tests/distributed/test_multiproc_executor.py b/tests/distributed/test_multiproc_executor.py new file mode 100644 index 000000000000..e741a79bc4ed --- /dev/null +++ b/tests/distributed/test_multiproc_executor.py @@ -0,0 +1,437 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Integration tests for MultiprocExecutor at the executor level. +This test directly tests the executor without going through the LLM interface, +focusing on executor initialization, RPC calls, and distributed execution. +""" + +import multiprocessing +import os + +from tests.utils import multi_gpu_test +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import get_open_port +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.executor.multiproc_executor import MultiprocExecutor + +MODEL = "facebook/opt-125m" + + +def create_vllm_config( + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: int = 256, + gpu_memory_utilization: float = 0.3, + distributed_executor_backend: str = "mp", + nnodes: int = 1, + node_rank: int = 0, + master_port: int = 0, +) -> VllmConfig: + """Create a VllmConfig for testing using EngineArgs.""" + engine_args = EngineArgs( + model=MODEL, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + ) + vllm_config = engine_args.create_engine_config() + + # Override distributed node settings if needed + if nnodes > 1 or node_rank > 0: + vllm_config.parallel_config.nnodes = nnodes + vllm_config.parallel_config.node_rank = node_rank + vllm_config.parallel_config.master_port = master_port + if nnodes > 1: + vllm_config.parallel_config.disable_custom_all_reduce = True + + return vllm_config + + +def create_test_scheduler_output(num_requests: int = 1) -> SchedulerOutput: + """Create a minimal SchedulerOutput for testing.""" + # This is a simplified version - in practice you'd need proper + # SchedulerOutput construction based on the actual vLLM v1 API + return SchedulerOutput( + scheduled_new_reqs=[], + scheduled_resumed_reqs=[], + scheduled_running_reqs=[], + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + ) + + +def test_multiproc_executor_initialization(): + """Test that MultiprocExecutor can be initialized with proper config.""" + vllm_config = create_vllm_config( + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + # Create executor - this should initialize workers + executor = MultiprocExecutor(vllm_config=vllm_config) + + # Verify executor properties + assert executor.world_size == 1, "World size should be 1 for single GPU" + assert executor.local_world_size == 1, "Local world size should be 1" + assert hasattr(executor, "workers"), "Executor should have workers" + assert len(executor.workers) == 1, "Should have 1 worker for single GPU" + + # Clean up + executor.shutdown() + + +@multi_gpu_test(num_gpus=2) +def test_multiproc_executor_initialization_tensor_parallel(): + """Test MultiprocExecutor initialization with tensor parallelism.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ) + + # Create executor + executor = MultiprocExecutor(vllm_config=vllm_config) + + # Verify executor properties + assert executor.world_size == 2, "World size should be 2 for TP=2" + 
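    # With nnodes at its default of 1, ParallelConfig.local_world_size
    # collapses to world_size, so the single-node TP=2 executor should also
    # report a local world size of 2.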
assert executor.local_world_size == 2, "Local world size should be 2" + assert len(executor.workers) == 2, "Should have 2 workers for TP=2" + + # Verify output rank calculation + output_rank = executor._get_output_rank() + assert output_rank == 0, "Output rank should be 0 for TP=2, PP=1" + + # Clean up + executor.shutdown() + + +@multi_gpu_test(num_gpus=2) +def test_multiproc_executor_collective_rpc(): + """Test collective RPC calls to all workers.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ) + + # Create executor + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Test check_health RPC - should work without errors + executor.check_health() + + # Test that RPC works correctly + # Note: We're just testing that the RPC mechanism works, + # not testing actual model execution here + assert not executor.is_failed, "Executor should not be in failed state" + + finally: + # Clean up + executor.shutdown() + + +def test_multiproc_executor_failure_callback(): + """Test failure callback registration and invocation.""" + vllm_config = create_vllm_config( + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Test callback registration + callback_invoked = [] + + def test_callback(): + callback_invoked.append(True) + + # Register callback + executor.register_failure_callback(test_callback) + + # Callback should not be invoked yet + assert len(callback_invoked) == 0, "Callback should not be invoked immediately" + + # Simulate failure + executor.is_failed = True + + # Register another callback - should be invoked immediately + executor.register_failure_callback(test_callback) + assert len(callback_invoked) == 1, ( + "Callback should be invoked when executor is failed" + ) + + finally: + # Clean up + executor.shutdown() + + +@multi_gpu_test(num_gpus=2) +def test_multiproc_executor_worker_monitor(): + """Test that worker monitor is set up correctly.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Verify all worker processes are alive + for worker in executor.workers: + assert worker.proc.is_alive(), f"Worker rank {worker.rank} should be alive" + + # Verify executor is not in failed state + assert not executor.is_failed, "Executor should not be in failed state" + + finally: + # Clean up + executor.shutdown() + + # After shutdown, workers should be terminated + import time + + time.sleep(0.5) # Give processes time to terminate + for worker in executor.workers: + assert not worker.proc.is_alive(), ( + f"Worker rank {worker.rank} should terminate after shutdown" + ) + + +@multi_gpu_test(num_gpus=2) +def test_multiproc_executor_get_response_message_queues(): + """Test message queue retrieval for different ranks.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Get all message queues + all_queues = executor.get_response_mqs() + assert len(all_queues) == 2, "Should have 2 message queues for 2 workers" + + # Get message queue for specific rank + rank0_queue = executor.get_response_mqs(unique_reply_rank=0) + assert len(rank0_queue) == 1, "Should have 1 message queue for rank 0" + + rank1_queue = executor.get_response_mqs(unique_reply_rank=1) + assert len(rank1_queue) == 1, "Should have 1 message queue for rank 1" + + finally: + # Clean up + 
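        # shutdown() is idempotent (exercised explicitly in
        # test_multiproc_executor_shutdown_cleanup), so the finally block can
        # call it unconditionally.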
executor.shutdown() + + +def test_multiproc_executor_shutdown_cleanup(): + """Test that shutdown properly cleans up resources.""" + vllm_config = create_vllm_config( + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + # Verify executor is set up + assert hasattr(executor, "workers"), "Executor should have workers" + assert len(executor.workers) > 0, "Should have at least one worker" + + # Shutdown + executor.shutdown() + + # Verify cleanup + import time + + time.sleep(0.5) # Give processes time to terminate + + for worker in executor.workers: + assert not worker.proc.is_alive(), "Worker processes should be terminated" + + # Verify shutdown event is set + assert executor.shutdown_event.is_set(), "Shutdown event should be set" + + # Multiple shutdowns should be safe (idempotent) + executor.shutdown() + executor.shutdown() + + +@multi_gpu_test(num_gpus=4) +def test_multiproc_executor_pipeline_parallel(): + """Test MultiprocExecutor with pipeline parallelism.""" + vllm_config = create_vllm_config( + tensor_parallel_size=2, + pipeline_parallel_size=2, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Verify executor properties + assert executor.world_size == 4, "World size should be 4 for TP=2, PP=2" + assert len(executor.workers) == 4, "Should have 4 workers" + + # Verify output rank calculation + # For TP=2, PP=2: output should be from the last PP stage (ranks 2-3) + # Specifically rank 2 (first rank of last PP stage) + output_rank = executor._get_output_rank() + assert output_rank == 2, "Output rank should be 2 (first rank of last PP stage)" + + # Verify max_concurrent_batches for pipeline parallel + assert executor.max_concurrent_batches == 2, ( + "Max concurrent batches should equal PP size" + ) + + finally: + # Clean up + executor.shutdown() + + +def test_multiproc_executor_properties(): + """Test various executor properties and configurations.""" + vllm_config = create_vllm_config( + tensor_parallel_size=1, + pipeline_parallel_size=1, + ) + + executor = MultiprocExecutor(vllm_config=vllm_config) + + try: + # Test supports_pp property + assert MultiprocExecutor.supports_pp is True, ( + "MultiprocExecutor should support pipeline parallelism" + ) + + # Test world_size calculation + assert executor.world_size == ( + executor.parallel_config.tensor_parallel_size + * executor.parallel_config.pipeline_parallel_size + ), "World size should equal TP * PP" + + # Test local_world_size calculation + assert executor.local_world_size == ( + executor.parallel_config.world_size // executor.parallel_config.nnodes + ), "Local world size should be world_size / nnodes" + + finally: + # Clean up + executor.shutdown() + + +@multi_gpu_test(num_gpus=4) +def test_multiproc_executor_multi_node(): + """ + Test MultiprocExecutor with multi-node configuration. 
+ This simulates 2 nodes with TP=4: + - Node 0 (rank 0): Uses GPUs 0,1 (CUDA_VISIBLE_DEVICES=0,1) with TP=2 + - Node 1 (rank 1): Uses GPUs 2,3 (CUDA_VISIBLE_DEVICES=2,3) with TP=2 + Total world_size = 4, nnodes = 2 + """ + port = get_open_port() + # symm_mem does not work for simulating multi instance in single node + os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" + + def run_node(node_rank: int, result_queue: multiprocessing.Queue, port: int): + """Run a single node's executor.""" + executor = None + try: + # Set CUDA_VISIBLE_DEVICES for this node + if node_rank == 0: + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" + else: + os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" + + # Create config for this node + vllm_config = create_vllm_config( + tensor_parallel_size=4, # Total TP across all nodes + pipeline_parallel_size=1, + nnodes=2, # 2 nodes + node_rank=node_rank, + master_port=port, # same port + ) + + # Create executor for this node + executor = MultiprocExecutor(vllm_config=vllm_config) + + # Verify node-specific properties + assert executor.world_size == 4, ( + f"World size should be 4 on node {node_rank}" + ) + assert executor.local_world_size == 2, ( + f"Local world size should be 2 on node {node_rank}" + ) + assert len(executor.workers) == 2, ( + f"Should have 2 local workers on node {node_rank}" + ) + + # Verify worker ranks are correct for this node + expected_ranks = [node_rank * 2, node_rank * 2 + 1] + actual_ranks = sorted([w.rank for w in executor.workers]) + assert actual_ranks == expected_ranks, ( + f"Node {node_rank} should have workers " + f"with ranks {expected_ranks}, got {actual_ranks}" + ) + # Verify all workers are alive + for worker in executor.workers: + assert worker.proc.is_alive(), ( + f"Worker rank {worker.rank} should be alive on node {node_rank}" + ) + # executor.gen + # Put success result in queue BEFORE shutdown to avoid hanging + result_queue.put({"node": node_rank, "success": True}) + import time + + time.sleep(2) + executor.shutdown() + except Exception as e: + # Put failure result in queue + result_queue.put({"node": node_rank, "success": False, "error": str(e)}) + raise e + finally: + if executor is not None: + executor.shutdown() + + # Create a queue to collect results from both processes + result_queue: multiprocessing.Queue[dict[str, int | bool]] = multiprocessing.Queue() + + # Start both node processes + processes = [] + for node_rank in range(2): + p = multiprocessing.Process( + target=run_node, + args=(node_rank, result_queue, port), + name=f"Node{node_rank}", + ) + p.start() + processes.append(p) + + # Wait for both processes to complete + all_completed = True + for p in processes: + p.join(timeout=60) + if p.is_alive(): + p.terminate() + p.join(timeout=20) + if p.is_alive(): + p.kill() + p.join() + all_completed = False + + # Check results from both nodes + results: list[dict[str, int | bool]] = [] + while len(results) < 2: + try: + result = result_queue.get(timeout=1) + results.append(result) + except Exception: + pass + assert all_completed, "Not all processes completed successfully" + assert len(results) == 2, f"Expected 2 results, got {len(results)}" + assert results[0]["success"], f"Node 0 failed: {results[0]}" + assert results[1]["success"], f"Node 1 failed: {results[1]}" diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index b19c8beeae3d..6dd89760c49f 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -210,6 +210,18 @@ class ParallelConfig: class is dynamically inherited by the worker class. 
This is used to inject new attributes and methods to the worker class for use in collective_rpc calls.""" + master_addr: str = "127.0.0.1" + """distributed master address for multi-node distributed + inference when distributed_executor_backend is mp.""" + master_port: int = 29501 + """distributed master port for multi-node distributed + inference when distributed_executor_backend is mp.""" + node_rank: int = 0 + """distributed node rank for multi-node distributed + inference when distributed_executor_backend is mp.""" + nnodes: int = 1 + """num of nodes for multi-node distributed + inference when distributed_executor_backend is mp.""" world_size: int = Field(init=False) """world_size is TPxPP, it affects the number of workers we create.""" @@ -387,6 +399,23 @@ def use_sequence_parallel_moe(self) -> bool: and self.data_parallel_size > 1 ) + @property + def node_rank_within_dp(self) -> int: + return self.node_rank % self.nnodes_within_dp + + @property + def nnodes_within_dp(self) -> int: + if self.nnodes == 1: + return 1 + data_parallel_node_size = ( + self.data_parallel_size // self.data_parallel_size_local + ) + return self.nnodes // data_parallel_node_size + + @property + def local_world_size(self) -> int: + return self.world_size // self.nnodes_within_dp + @staticmethod def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool: tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu") @@ -528,6 +557,8 @@ def __post_init__(self) -> None: ray_found = ray_utils.ray_is_available() if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: backend = "uni" + elif current_platform.is_cuda() and self.nnodes > 1: + backend = "mp" elif ( current_platform.is_cuda() and cuda_device_count_stateless() < self.world_size @@ -565,6 +596,10 @@ def __post_init__(self) -> None: "max_parallel_loading_workers is currently " "not supported and will be ignored." ) + if self.distributed_executor_backend != "mp" and self.nnodes > 1: + raise ValueError( + "nnodes > 1 can only be set when distributed exectuor backend is mp." + ) @property def use_ray(self) -> bool: @@ -607,6 +642,11 @@ def _verify_args(self) -> Self: "Disabled the custom all-reduce kernel because it is not " "supported on current platform." ) + if self.nnodes > 1: + self.disable_custom_all_reduce = True + logger.debug( + "Disabled the custom all-reduce since we are running on multi-node." + ) if self.ray_workers_use_nsight and not self.use_ray: raise ValueError( "Unable to use nsight profiling unless workers run with Ray." diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 5046cac2e90a..052df19e34d7 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -8,7 +8,7 @@ from multiprocessing import shared_memory from pickle import PickleBuffer from threading import Event -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch import torch @@ -602,13 +602,87 @@ def broadcast_object(self, obj=None): return obj return self.dequeue() + @staticmethod + def create_from_process_group_single_reader( + pg: ProcessGroup, + max_chunk_bytes, + max_chunks, + reader_rank: int = 0, + blocking: bool = False, + ) -> tuple["MessageQueue", list[Handle]]: + """ + Creates a MessageQueue for a process group with a single reader. 
+ + This method is designed for scenarios where only one process (the reader) + will consume messages, and all other processes are writers. It sets up + the shared memory buffer and communication handles accordingly, and + gathers the handles from all processes to the reader. + + Args: + pg (ProcessGroup): The torch distributed process group. + max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer. + max_chunks (int): Maximum number of chunks in the buffer. + reader_rank (int, optional): The global rank that will act as the reader. + Defaults to 0. + blocking (bool, optional): If True, blocks until all processes are ready. + Defaults to False. + + Returns: + tuple[MessageQueue, list[Handle]]: + The MessageQueue instance for the calling process, + and a list of handles (only non-empty for the reader process). + """ + local_size = torch.cuda.device_count() + rank = dist.get_rank() + same_node = rank // local_size == reader_rank // local_size + buffer_io = MessageQueue( + n_reader=1, + n_local_reader=1 if same_node else 0, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None + dist.gather_object(handle, handles, dst=reader_rank, group=pg) + if blocking: + buffer_io.wait_until_ready() + return buffer_io, cast(list[Handle], handles or []) + @staticmethod def create_from_process_group( pg: ProcessGroup | StatelessProcessGroup, max_chunk_bytes, max_chunks, - writer_rank=0, + writer_rank: int = 0, + external_writer_handle=None, + blocking: bool = True, ) -> "MessageQueue": + """ + Creates a MessageQueue for a distributed process group with one writer and + multiple readers. + + This method is designed for scenarios where one process (the writer) sends + messages, and all other processes (the readers) receive messages. It sets up + the shared memory buffer and socket communication handles accordingly, and + broadcasts the handle from the writer to all readers. + + Args: + pg (ProcessGroup | StatelessProcessGroup): The torch distributed process + group. + max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer. + max_chunks (int): Maximum number of chunks in the buffer. + writer_rank (int, optional): The global rank that will act as the writer. + Defaults to 0. + external_writer_handle (Handle, optional): Used when there is a handle + from an external Message Queue. If provided, use this handle to init + PG writer message queue instead of creating a new one. Defaults to None. + blocking (bool, optional): If True, blocks until all processes are ready. + Defaults to True. + + Returns: + MessageQueue: The MessageQueue instance for the calling process. 
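        Example:
            An illustrative sketch (writer rank 0, chunk sizes matching
            GroupCoordinator.create_mq_broadcaster); the group and payload
            names are placeholders, not lifted from an existing call site:

                mq = MessageQueue.create_from_process_group(
                    cpu_group, max_chunk_bytes=1 << 22, max_chunks=6,
                    writer_rank=0,
                )
                if dist.get_rank(cpu_group) == 0:
                    mq.enqueue(payload)       # writer side
                else:
                    payload = mq.dequeue()    # reader side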
+ + """ if isinstance(pg, ProcessGroup): group_rank = dist.get_rank(pg) group_world_size = dist.get_world_size(pg) @@ -617,23 +691,26 @@ def create_from_process_group( group_rank = pg.rank group_world_size = pg.world_size global_ranks = list(range(pg.world_size)) - from vllm.distributed.parallel_state import in_the_same_node_as status = in_the_same_node_as(pg, source_rank=writer_rank) - same_node_ranks = [i for i, s in enumerate(status) if s] - n_reader = group_world_size - 1 - n_local_reader = len(same_node_ranks) - 1 - local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] - buffer_io: MessageQueue if group_rank == writer_rank: - buffer_io = MessageQueue( - n_reader=n_reader, - n_local_reader=n_local_reader, - local_reader_ranks=local_reader_ranks, - max_chunk_bytes=max_chunk_bytes, - max_chunks=max_chunks, - ) + if external_writer_handle is not None: + buffer_io = MessageQueue.create_from_handle( + external_writer_handle, group_rank + ) + else: + same_node_ranks = [i for i, s in enumerate(status) if s] + n_reader = group_world_size - 1 + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) handle = buffer_io.export_handle() if isinstance(pg, ProcessGroup): dist.broadcast_object_list( @@ -651,5 +728,6 @@ def create_from_process_group( else: handle = pg.broadcast_obj(None, writer_rank) buffer_io = MessageQueue.create_from_handle(handle, group_rank) - buffer_io.wait_until_ready() + if blocking: + buffer_io.wait_until_ready() return buffer_io diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a9b01e82562b..571fbbcc6a1e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -385,6 +385,33 @@ def __init__( torch.ops._C, "init_shm_manager" ) + def create_mq_broadcaster( + self, writer_rank=0, external_writer_handle=None, blocking=True + ): + from vllm.distributed.device_communicators.shm_broadcast import MessageQueue + + return MessageQueue.create_from_process_group( + self.cpu_group, + 1 << 22, + 6, + writer_rank=writer_rank, + external_writer_handle=external_writer_handle, + blocking=blocking, + ) + + def create_single_reader_mq_broadcasters( + self, reader_rank_in_group=0, blocking=False + ): + from vllm.distributed.device_communicators.shm_broadcast import MessageQueue + + return MessageQueue.create_from_process_group_single_reader( + self.cpu_group, + 1 << 22, + 6, + reader_rank=self.ranks[reader_rank_in_group], + blocking=blocking, + ) + @property def first_rank(self): """Return the global rank of the first process in the group""" @@ -997,6 +1024,7 @@ def combine( _WORLD: GroupCoordinator | None = None +_INNER_DP_WORLD: GroupCoordinator | None = None _NODE_COUNT: int | None = None @@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator: return _WORLD +def get_inner_dp_world_group() -> GroupCoordinator: + assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized" + return _INNER_DP_WORLD + + def init_world_group( ranks: list[int], local_rank: int, backend: str ) -> GroupCoordinator: @@ -1023,12 +1056,13 @@ def init_model_parallel_group( backend: str, use_message_queue_broadcaster: bool = False, group_name: str | None = None, + use_device_communicator: bool = True, ) -> GroupCoordinator: return GroupCoordinator( 
group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_device_communicator=True, + use_device_communicator=use_device_communicator, use_message_queue_broadcaster=use_message_queue_broadcaster, group_name=group_name, ) @@ -1143,7 +1177,14 @@ def init_distributed_environment( from vllm.config import get_current_vllm_config config = get_current_vllm_config() - if ( + if config is not None and config.parallel_config.nnodes > 1: + parallel_config = config.parallel_config + ip = parallel_config.master_addr + rank = parallel_config.data_parallel_rank * world_size + rank + world_size = parallel_config.world_size_across_dp + port = parallel_config.master_port + distributed_init_method = get_distributed_init_method(ip, port) + elif ( config is not None and config.parallel_config.data_parallel_size > 1 and config.parallel_config.distributed_executor_backend != "external_launcher" @@ -1164,6 +1205,14 @@ def init_distributed_environment( distributed_init_method, ) if not torch.distributed.is_initialized(): + logger.info( + "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " "distributed environment" @@ -1192,16 +1241,36 @@ def init_distributed_environment( # local rank not set, this usually happens in single-node # setting, where we can use rank as local rank local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank - global _WORLD, _NODE_COUNT + global _WORLD, _NODE_COUNT, _INNER_DP_WORLD if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) - _NODE_COUNT = _node_count(_WORLD.cpu_group) + if config.parallel_config.nnodes > 1: + _NODE_COUNT = config.parallel_config.nnodes + else: + _NODE_COUNT = _node_count(_WORLD.cpu_group) logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size" ) + if config.parallel_config.nnodes_within_dp > 1: + if parallel_config.data_parallel_size > 1: + world_size_inner_dp = parallel_config.world_size + group_ranks = [ + [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)] + for dp_rank in range(parallel_config.data_parallel_size) + ] + _INNER_DP_WORLD = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="inner_dp_world", + use_device_communicator=False, + ) + else: + _INNER_DP_WORLD = _WORLD def initialize_model_parallel( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b12b7082af62..2e7f9df2143c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -383,6 +383,10 @@ class EngineArgs: ) = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size + master_addr: str = ParallelConfig.master_addr + master_port: int = ParallelConfig.master_port + nnodes: int = ParallelConfig.nnodes + node_rank: int = ParallelConfig.node_rank tensor_parallel_size: int = ParallelConfig.tensor_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size dcp_kv_cache_interleave_size: int = 
ParallelConfig.dcp_kv_cache_interleave_size @@ -393,6 +397,7 @@ class EngineArgs: data_parallel_address: str | None = None data_parallel_rpc_port: int | None = None data_parallel_hybrid_lb: bool = False + data_parallel_external_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel all2all_backend: str | None = ParallelConfig.all2all_backend @@ -743,6 +748,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "-pp", **parallel_kwargs["pipeline_parallel_size"], ) + parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"]) + parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"]) + parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"]) + parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"]) parallel_group.add_argument( "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] ) @@ -797,7 +806,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='Backend for data parallel, either "mp" or "ray".', ) parallel_group.add_argument( - "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"] + "--data-parallel-hybrid-lb", + "-dph", + **parallel_kwargs["data_parallel_hybrid_lb"], + ) + parallel_group.add_argument( + "--data-parallel-external-lb", + "-dpe", + **parallel_kwargs["data_parallel_external_lb"], ) parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"] @@ -1416,12 +1432,56 @@ def create_engine_config( assert not headless or not self.data_parallel_hybrid_lb, ( "data_parallel_hybrid_lb is not applicable in headless mode" ) - - data_parallel_external_lb = self.data_parallel_rank is not None + assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), ( + "data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True." + ) + assert self.data_parallel_backend == "mp" or self.nnodes == 1, ( + "nnodes > 1 is only supported with data_parallel_backend=mp" + ) + inferred_data_parallel_rank = 0 + if self.nnodes > 1: + world_size = ( + self.data_parallel_size + * self.pipeline_parallel_size + * self.tensor_parallel_size + ) + world_size_within_dp = ( + self.pipeline_parallel_size * self.tensor_parallel_size + ) + local_world_size = world_size // self.nnodes + assert world_size % self.nnodes == 0, ( + f"world_size={world_size} must be divisible by nnodes={self.nnodes}." + ) + assert self.node_rank < self.nnodes, ( + f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}." + ) + inferred_data_parallel_rank = ( + self.node_rank * local_world_size + ) // world_size_within_dp + if self.data_parallel_size > 1 and self.data_parallel_external_lb: + self.data_parallel_rank = inferred_data_parallel_rank + logger.info( + "Inferred data_parallel_rank %d from node_rank %d for external lb", + self.data_parallel_rank, + self.node_rank, + ) + elif self.data_parallel_size_local is None: + # Infer data parallel size local for internal dplb: + self.data_parallel_size_local = max( + local_world_size // world_size_within_dp, 1 + ) + data_parallel_external_lb = ( + self.data_parallel_external_lb or self.data_parallel_rank is not None + ) # Local DP rank = 1, use pure-external LB. if data_parallel_external_lb: + assert self.data_parallel_rank is not None, ( + "data_parallel_rank or node_rank must be spefified if " + "data_parallel_external_lb is enable." 
+ ) assert self.data_parallel_size_local in (1, None), ( - "data_parallel_size_local must be 1 when data_parallel_rank is set" + "data_parallel_size_local must be 1 or None when data_parallel_rank " + "is set" ) data_parallel_size_local = 1 # Use full external lb if we have local_size of 1. @@ -1435,6 +1495,11 @@ def create_engine_config( if self.data_parallel_hybrid_lb and data_parallel_size_local == 1: # Use full external lb if we have local_size of 1. + logger.warning( + "data_parallel_hybrid_lb is not eligible when " + "data_parallel_size_local = 1, autoswitch to " + "data_parallel_external_lb." + ) data_parallel_external_lb = True self.data_parallel_hybrid_lb = False @@ -1442,7 +1507,15 @@ def create_engine_config( # Disable hybrid LB mode if set for a single node self.data_parallel_hybrid_lb = False - self.data_parallel_rank = self.data_parallel_start_rank or 0 + self.data_parallel_rank = ( + self.data_parallel_start_rank or inferred_data_parallel_rank + ) + if self.nnodes > 1: + logger.info( + "Inferred data_parallel_rank %d from node_rank %d", + self.data_parallel_rank, + self.node_rank, + ) else: assert not self.data_parallel_hybrid_lb, ( "data_parallel_size_local must be set to use data_parallel_hybrid_lb." @@ -1472,7 +1545,9 @@ def create_engine_config( "data_parallel_backend can only be ray or mp, got %s", self.data_parallel_backend, ) - data_parallel_address = ParallelConfig.data_parallel_master_ip + data_parallel_address = ( + self.master_addr or ParallelConfig.data_parallel_master_ip + ) else: data_parallel_address = self.data_parallel_address @@ -1501,6 +1576,10 @@ def create_engine_config( data_parallel_rank=self.data_parallel_rank or 0, data_parallel_external_lb=data_parallel_external_lb, data_parallel_size_local=data_parallel_size_local, + master_addr=self.master_addr, + master_port=self.master_port, + nnodes=self.nnodes, + node_rank=self.node_rank, data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_backend=self.data_parallel_backend, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 2678658dd126..96608f360e17 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -24,6 +24,7 @@ from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines from vllm.v1.executor import Executor +from vllm.v1.executor.multiproc_executor import MultiprocExecutor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure @@ -97,18 +98,40 @@ def run_headless(args: argparse.Namespace): if local_engine_count <= 0: raise ValueError("data_parallel_size_local must be > 0 in headless mode") - host = parallel_config.data_parallel_master_ip - port = engine_args.data_parallel_rpc_port # add to config too - handshake_address = get_tcp_uri(host, port) + shutdown_requested = False # Catch SIGTERM and SIGINT to allow graceful shutdown. def signal_handler(signum, frame): + nonlocal shutdown_requested logger.debug("Received %d signal.", signum) - raise SystemExit + if not shutdown_requested: + shutdown_requested = True + raise SystemExit signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + if parallel_config.node_rank_within_dp > 0: + from vllm.version import __version__ as VLLM_VERSION + + # Run headless workers (for multi-node PP/TP). 
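        # A node with node_rank_within_dp > 0 does not start an engine or
        # API server: it only joins the torch.distributed process group,
        # hosts its share of worker processes, and then blocks in the
        # inline worker monitor until those workers exit.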
+ host = parallel_config.master_addr + head_node_address = f"{host}:{parallel_config.master_port}" + logger.info( + "Launching vLLM (v%s) headless multiproc executor, " + "with head node address %s for torch.distributed process group.", + VLLM_VERSION, + head_node_address, + ) + + executor = MultiprocExecutor(vllm_config, monitor_workers=False) + executor.start_worker_monitor(inline=True) + return + + host = parallel_config.data_parallel_master_ip + port = parallel_config.data_parallel_rpc_port + handshake_address = get_tcp_uri(host, port) + logger.info( "Launching %d data parallel engine(s) in headless mode, " "with head node address %s.", diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index e74519b21aa6..d65cad7af03d 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -183,15 +183,19 @@ def set_device_control_env_var( for engine subprocess. """ world_size = vllm_config.parallel_config.world_size + local_world_size = vllm_config.parallel_config.local_world_size evar = current_platform.device_control_env_var - value = get_device_indices(evar, local_dp_rank, world_size) + value = get_device_indices(evar, local_dp_rank, world_size, local_world_size) with patch.dict(os.environ, values=((evar, value),)): yield def get_device_indices( - device_control_env_var: str, local_dp_rank: int, world_size: int + device_control_env_var: str, + local_dp_rank: int, + world_size: int, + local_world_size: int | None = None, ): """ Returns a comma-separated string of device indices for the specified @@ -200,10 +204,15 @@ def get_device_indices( For example, if world_size=2 and local_dp_rank=1, and there are 4 devices, this will select devices 2 and 3 for local_dp_rank=1. """ + if local_world_size is None: + local_world_size = world_size try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size) + for i in range( + local_dp_rank * world_size, + local_dp_rank * world_size + local_world_size, + ) ) except IndexError as e: raise Exception( diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 1e249161c688..924ee1443040 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -10,7 +10,7 @@ import traceback import weakref from collections import deque -from collections.abc import Callable +from collections.abc import Callable, Sequence from concurrent.futures import Future, InvalidStateError from contextlib import suppress from dataclasses import dataclass @@ -33,6 +33,7 @@ from vllm.distributed.parallel_state import ( get_dp_group, get_ep_group, + get_inner_dp_world_group, get_pp_group, get_tp_group, ) @@ -89,6 +90,10 @@ def wait_for_response(self, get_response: Callable): class MultiprocExecutor(Executor): supports_pp: bool = True + def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True): + self.monitor_workers = monitor_workers + super().__init__(vllm_config) + def _init_executor(self) -> None: # Call self.shutdown at exit to clean up # and ensure workers will be terminated. @@ -98,6 +103,12 @@ def _init_executor(self) -> None: self.failure_callback: FailureCallback | None = None self.world_size = self.parallel_config.world_size + assert self.world_size % self.parallel_config.nnodes_within_dp == 0, ( + f"global world_size ({self.parallel_config.world_size}) must be " + f"divisible by nnodes_within_dp " + f"({self.parallel_config.nnodes_within_dp}). 
" + ) + self.local_world_size = self.parallel_config.local_world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size assert self.world_size == tensor_parallel_size * pp_parallel_size, ( @@ -115,27 +126,37 @@ def _init_executor(self) -> None: distributed_init_method = get_distributed_init_method( get_loopback_ip(), get_open_port() ) - + self.rpc_broadcast_mq: MessageQueue | None = None + scheduler_output_handle: Handle | None = None # Initialize worker and set up message queues for SchedulerOutputs # and ModelRunnerOutputs - max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 - self.rpc_broadcast_mq = MessageQueue( - self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes - ) - scheduler_output_handle = self.rpc_broadcast_mq.export_handle() - + if self.parallel_config.node_rank_within_dp == 0: + # For leader node within each dp rank, + # each dp will have its own leader multiproc executor. + max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 + self.rpc_broadcast_mq = MessageQueue( + self.world_size, + self.local_world_size, + max_chunk_bytes=max_chunk_bytes, + connect_ip=self.parallel_config.master_addr, + ) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers context = get_mp_context() shared_worker_lock = context.Lock() unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: - for rank in range(self.world_size): + global_start_rank = ( + self.local_world_size * self.parallel_config.node_rank_within_dp + ) + for local_rank in range(self.local_world_size): + global_rank = global_start_rank + local_rank unready_workers.append( WorkerProc.make_worker_process( vllm_config=self.vllm_config, - local_rank=rank, - rank=rank, + local_rank=local_rank, + rank=global_rank, distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, shared_worker_lock=shared_worker_lock, @@ -144,15 +165,38 @@ def _init_executor(self) -> None: # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. + + # Wait for all local workers to be ready. self.workers = WorkerProc.wait_for_ready(unready_workers) + # Start background thread to monitor worker health if not in headless mode. + if self.monitor_workers: + self.start_worker_monitor() + + self.response_mqs = [] + # Only leader node have remote response mqs + if self.parallel_config.node_rank_within_dp == 0: + for rank in range(self.world_size): + if rank < self.local_world_size: + local_message_queue = self.workers[rank].worker_response_mq + assert local_message_queue is not None + self.response_mqs.append(local_message_queue) + else: + remote_message_queue = self.workers[0].peer_worker_response_mqs[ + rank + ] + assert remote_message_queue is not None + self.response_mqs.append(remote_message_queue) + # Ensure message queues are ready. Will deadlock if re-ordered # Must be kept consistent with the WorkerProc. - self.rpc_broadcast_mq.wait_until_ready() - for w in self.workers: - w.worker_response_mq.wait_until_ready() - self.start_worker_monitor() + # Wait for all input mqs to be ready. + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.wait_until_ready() + # Wait for all remote response mqs to be ready. 
+ for response_mq in self.response_mqs: + response_mq.wait_until_ready() success = True finally: if not success: @@ -167,7 +211,7 @@ def _init_executor(self) -> None: self.output_rank = self._get_output_rank() - def start_worker_monitor(self): + def start_worker_monitor(self, inline=False) -> None: workers = self.workers self_ref = weakref.ref(self) @@ -191,9 +235,13 @@ def monitor_workers(): _self.failure_callback = None callback() - Thread( - target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor" - ).start() + if not inline: + Thread( + target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor" + ).start() + return + + monitor_workers() def register_failure_callback(self, callback: FailureCallback): if self.is_failed: @@ -246,7 +294,9 @@ def collective_rpc( # type: ignore[override] ) -> Any | list[Any] | Future[Any | list[Any]]: """Returns single result if unique_reply_rank and/or kv_output_aggregator is provided, otherwise list.""" - + assert self.rpc_broadcast_mq is not None, ( + "collective_rpc should not be called on follower node" + ) if self.is_failed: raise RuntimeError("Executor failed.") @@ -268,20 +318,20 @@ def collective_rpc( # type: ignore[override] send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL) self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank)) - workers = ( - (self.workers[output_rank],) if output_rank is not None else self.workers - ) + response_mqs: Sequence[MessageQueue] = self.response_mqs + if output_rank is not None: + response_mqs = (response_mqs[output_rank],) shutdown_event = self.shutdown_event def get_response(): responses = [] - for w in workers: + for mq in response_mqs: dequeue_timeout = ( None if deadline is None else (deadline - time.monotonic()) ) try: - status, result = w.worker_response_mq.dequeue( + status, result = mq.dequeue( timeout=dequeue_timeout, cancel=shutdown_event ) except TimeoutError as e: @@ -390,17 +440,26 @@ class UnreadyWorkerProcHandle: class WorkerProcHandle: proc: BaseProcess rank: int - worker_response_mq: MessageQueue # The worker process writes to this MQ + # The worker process writes to this MQ in single-node mode + worker_response_mq: MessageQueue | None + # This is only non empty on driver node, + # the peer worker process i writes to MQ + # `peer_worker_response_mqs[i]` + peer_worker_response_mqs: list[MessageQueue | None] death_writer: Connection | None = None @classmethod def from_unready_handle( - cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue + cls, + unready_handle: UnreadyWorkerProcHandle, + worker_response_mq: MessageQueue | None, + peer_worker_response_mqs: list[MessageQueue | None], ) -> "WorkerProcHandle": return cls( proc=unready_handle.proc, rank=unready_handle.rank, worker_response_mq=worker_response_mq, + peer_worker_response_mqs=peer_worker_response_mqs, death_writer=unready_handle.death_writer, ) @@ -410,6 +469,38 @@ class WorkerProc: READY_STR = "READY" + def _init_message_queues( + self, input_shm_handle: Handle, vllm_config: VllmConfig + ) -> None: + if vllm_config.parallel_config.nnodes_within_dp == 1: + # Initialize MessageQueue for receiving SchedulerOutput + self.rpc_broadcast_mq = MessageQueue.create_from_handle( + input_shm_handle, self.worker.rank + ) + + # Initializes a message queue for sending the model output + self.worker_response_mq: MessageQueue = MessageQueue(1, 1) + self.peer_response_handles = [] + else: + # Initialize remote MessageQueue for receiving SchedulerOutput across nodes + 
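            # The executor process on the leader node created this queue and is
            # its only writer; its handle arrives here as input_shm_handle and
            # is re-broadcast over the inner-DP CPU group so that every worker,
            # local (shared memory) or remote (socket), attaches as a reader.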
self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster( + external_writer_handle=input_shm_handle, + # Since there is external_writer_handle from executor proc, + # where the ready signal from actual writer is sent out of the + # create_mq_broadcaster method and after this setup, we make it + # non blocking. The handshake will be triggered when + # worker.rpc_broadcast_mq.wait_until_ready() is called + blocking=False, + ) + # Initializes remote message queue for sending the model output to the + # driver worker, exposing peer_response_handles for driver worker + # that include handles for all ranks + self.worker_response_mq, self.peer_response_handles = ( + get_inner_dp_world_group().create_single_reader_mq_broadcasters( + reader_rank_in_group=0 + ) + ) + def __init__( self, vllm_config: VllmConfig, @@ -420,13 +511,15 @@ def __init__( shared_worker_lock: LockType, ): self.rank = rank - wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank) + wrapper = WorkerWrapperBase( + vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank + ) # TODO: move `init_worker` to executor level as a collective rpc call all_kwargs: list[dict] = [ {} for _ in range(vllm_config.parallel_config.world_size) ] is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0 - all_kwargs[rank] = { + all_kwargs[local_rank] = { "vllm_config": vllm_config, "local_rank": local_rank, "rank": rank, @@ -437,14 +530,6 @@ def __init__( wrapper.init_worker(all_kwargs) self.worker = wrapper - # Initialize MessageQueue for receiving SchedulerOutput - self.rpc_broadcast_mq = MessageQueue.create_from_handle( - input_shm_handle, self.worker.rank - ) - - # Initializes a message queue for sending the model output - self.worker_response_mq = MessageQueue(1, 1) - scheduler_config = vllm_config.scheduler_config self.use_async_scheduling = scheduler_config.async_scheduling if self.use_async_scheduling: @@ -465,6 +550,7 @@ def __init__( ) # Load model + self._init_message_queues(input_shm_handle, vllm_config) self.worker.load_model() # Enable environment variable cache (e.g. assume no more @@ -511,6 +597,27 @@ def make_worker_process( # death_reader in child will get EOFError return UnreadyWorkerProcHandle(proc, rank, reader, death_writer) + @staticmethod + def wait_for_response_handle_ready( + handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle + ) -> WorkerProcHandle: + response_handle = handles["handle"] + worker_response_mq: MessageQueue | None = None + if len(response_handle.local_reader_ranks) > 0: + worker_response_mq = MessageQueue.create_from_handle(response_handle, 0) + peer_response_handles = handles["peer_response_handles"] + peer_worker_response_mqs = [ + MessageQueue.create_from_handle(handle, -1) + if handle.remote_subscribe_addr is not None + else None + for handle in peer_response_handles + ] + return WorkerProcHandle.from_unready_handle( + proc_handle, + worker_response_mq, + peer_worker_response_mqs=peer_worker_response_mqs, + ) + @staticmethod def wait_for_ready( unready_proc_handles: list[UnreadyWorkerProcHandle], @@ -536,16 +643,10 @@ def wait_for_ready( if response["status"] != "READY": raise e - # Extract the message queue handle. 
- worker_response_mq = MessageQueue.create_from_handle( - response["handle"], 0 - ) - ready_proc_handles[unready_proc_handle.rank] = ( - WorkerProcHandle.from_unready_handle( - unready_proc_handle, worker_response_mq - ) + idx = unready_proc_handle.rank % len(ready_proc_handles) + ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready( + response, unready_proc_handle ) - except EOFError: e.__suppress_context__ = True raise e from None @@ -617,12 +718,14 @@ def monitor_parent_death(): { "status": WorkerProc.READY_STR, "handle": worker.worker_response_mq.export_handle(), + "peer_response_handles": worker.peer_response_handles, } ) # Ensure message queues are ready. Will deadlock if re-ordered. # Must be kept consistent with the Executor - worker.rpc_broadcast_mq.wait_until_ready() + if worker.rpc_broadcast_mq is not None: + worker.rpc_broadcast_mq.wait_until_ready() worker.worker_response_mq.wait_until_ready() ready_writer.close() ready_writer = None diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 19061fcffdf1..7d6832c57f78 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -188,6 +188,7 @@ def init_device(self): and self.parallel_config.distributed_executor_backend not in ["ray", "external_launcher"] and self.vllm_config.parallel_config.data_parallel_backend != "ray" + and self.vllm_config.parallel_config.nnodes_within_dp == 1 ): # Use local DP rank if available, otherwise use global DP rank. dp_local_rank = self.parallel_config.data_parallel_rank_local @@ -204,7 +205,14 @@ def init_device(self): assert self.local_rank < torch.cuda.device_count(), ( f"DP adjusted local rank {self.local_rank} is out of bounds. " ) - + visible_device_count = ( + torch.cuda.device_count() if torch.cuda.is_available() else 0 + ) + assert self.parallel_config.local_world_size <= visible_device_count, ( + f"local_world_size ({self.parallel_config.local_world_size}) must be " + f"less than or equal to the number of visible devices " + f"({visible_device_count})." + ) self.device = torch.device(f"cuda:{self.local_rank}") current_platform.set_device(self.device) diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 30ea0ab77bd9..0fc8515c98ef 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -182,6 +182,7 @@ def __init__( self, vllm_config: VllmConfig, rpc_rank: int = 0, + global_rank: int | None = None, ) -> None: """ Initialize the worker wrapper with the given vllm_config and rpc_rank. @@ -194,6 +195,7 @@ def __init__( group. """ self.rpc_rank = rpc_rank + self.global_rank = self.rpc_rank if global_rank is None else global_rank self.worker: WorkerBase | None = None # do not store this `vllm_config`, `init_worker` will set the final @@ -314,7 +316,7 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None: assert self.worker is not None def initialize_from_config(self, kv_cache_configs: list[Any]) -> None: - kv_cache_config = kv_cache_configs[self.rpc_rank] + kv_cache_config = kv_cache_configs[self.global_rank] with set_current_vllm_config(self.vllm_config): self.worker.initialize_from_config(kv_cache_config) # type: ignore
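For reference, a minimal sketch of the rank layout implied by the new ParallelConfig.local_world_size property and the global_start_rank computation in MultiprocExecutor._init_executor, for the single-data-parallel case; the helper name below is illustrative only:

def node_worker_ranks(tp: int, pp: int, nnodes: int, node_rank: int) -> list[int]:
    world_size = tp * pp                      # ParallelConfig.world_size
    assert world_size % nnodes == 0
    local_world_size = world_size // nnodes   # ParallelConfig.local_world_size (DP=1)
    start = node_rank * local_world_size      # global_start_rank in _init_executor
    return list(range(start, start + local_world_size))

# TP=4, PP=1 across 2 nodes, as in test_multiproc_executor_multi_node:
assert node_worker_ranks(tp=4, pp=1, nnodes=2, node_rank=0) == [0, 1]
assert node_worker_ranks(tp=4, pp=1, nnodes=2, node_rank=1) == [2, 3]

Each node launches its own MultiprocExecutor over its slice of ranks; only the node with node_rank_within_dp == 0 owns the scheduler-output broadcast queue and answers collective RPCs, while the other nodes run the headless executor path added to run_headless.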