[V1] Support MP Executor for multi node distributed inference #23691
@@ -67,6 +67,15 @@ class EPLBConfig:
 class ParallelConfig:
     """Configuration for the distributed execution."""

+    distributed_master_ip: str = "127.0.0.1"
+    """Distributed master IP for multi-node distributed
+    inference when distributed_executor_backend is mp."""
+    distributed_master_port: int = 0
+    """Distributed master port for multi-node distributed inference."""
+    distributed_node_rank: int = 0
+    """Rank of this node in the multi-node deployment."""
+    distributed_node_size: int = 1
+    """Total number of nodes in the multi-node deployment."""
     pipeline_parallel_size: int = 1
     """Number of pipeline parallel groups."""
     tensor_parallel_size: int = 1
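For readers of the diff, here is a minimal sketch (not code from this PR) of how these fields typically map onto worker ranks: with distributed_node_size nodes of equal GPU count, each node hosts a contiguous slice of the global ranks, and every worker rendezvous at distributed_master_ip:distributed_master_port. The helper name node_worker_ranks and the gpus_per_node argument below are illustrative assumptions.

def node_worker_ranks(node_rank: int, node_size: int, gpus_per_node: int) -> list[int]:
    """Global worker ranks hosted on the node with the given distributed_node_rank."""
    assert 0 <= node_rank < node_size
    start = node_rank * gpus_per_node
    return list(range(start, start + gpus_per_node))

# Example: distributed_node_size=2 with 8 GPUs per node gives a world size of 16;
# node 0 hosts ranks 0-7, node 1 hosts ranks 8-15, and all workers rendezvous at
# distributed_master_ip:distributed_master_port.
assert node_worker_ranks(node_rank=1, node_size=2, gpus_per_node=8) == list(range(8, 16))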
@@ -516,9 +525,12 @@ def __post_init__(self) -> None:
             ray_found = ray_utils.ray_is_available()
             if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
                 backend = "uni"
+            elif current_platform.is_cuda() and self.distributed_node_size > 1:
+                backend = "mp"
             elif (
                 current_platform.is_cuda()
                 and cuda_device_count_stateless() < self.world_size
+                and self.distributed_node_size == 1
             ):
                 if not ray_found:
                     raise ValueError(
@@ -593,6 +605,11 @@ def _verify_args(self) -> Self:
                 "Disabled the custom all-reduce kernel because it is not "
                 "supported on current platform."
             )
+        if self.distributed_node_size > 1:
+            self.disable_custom_all_reduce = True
+            logger.debug(
+                "Disabled the custom all-reduce since we are running on multi-node."
+            )
         if self.ray_workers_use_nsight and not self.use_ray:
             raise ValueError(
                 "Unable to use nsight profiling unless workers run with Ray."
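Restated outside the diff, the backend auto-selection after this change is roughly the following. This is a simplified sketch, not the actual __post_init__ code: the TPU/SPMD branch is omitted, the error message is paraphrased, and the final "mp" fallback for a single node with enough local GPUs is an assumption about the surrounding code not shown in this hunk.

def pick_distributed_executor_backend(
    is_cuda: bool,
    node_size: int,
    local_gpu_count: int,
    world_size: int,
    ray_available: bool,
) -> str:
    if is_cuda and node_size > 1:
        # New in this PR: multi-node deployments select the MP executor directly.
        return "mp"
    if is_cuda and local_gpu_count < world_size and node_size == 1:
        # Single node without enough local GPUs still requires Ray.
        if not ray_available:
            raise ValueError("world_size exceeds local GPU count and Ray is unavailable")
        return "ray"
    # Otherwise fall back to the existing default (assumed here to be "mp").
    return "mp"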
@@ -594,12 +594,42 @@ def broadcast_object(self, obj=None):
             return obj
         return self.dequeue()

+    @staticmethod
+    def create_from_process_group_single_reader(
+        pg: ProcessGroup,
+        max_chunk_bytes,
+        max_chunks,
+        reader_rank: int = 0,
+        blocking: bool = False,
+    ):
+        """Create a message queue with a single reader at reader_rank."""
+        # We assume the same number of GPUs per node across the group.
+        local_size = torch.cuda.device_count()
+        group_rank = dist.get_rank(pg)
+        same_node = group_rank // local_size == reader_rank // local_size
+        buffer_io = MessageQueue(
+            n_reader=1,
+            n_local_reader=1 if same_node else 0,
+            max_chunk_bytes=max_chunk_bytes,
+            max_chunks=max_chunks,
+        )
+        handle = buffer_io.export_handle()
+        handles = (
+            [None] * dist.get_world_size(pg) if group_rank == reader_rank else None
+        )
+        dist.gather_object(handle, handles, dst=reader_rank, group=pg)
+        if blocking:
+            buffer_io.wait_until_ready()
+        return (buffer_io, handles if handles is not None else [])
+
     @staticmethod
     def create_from_process_group(
         pg: ProcessGroup | StatelessProcessGroup,
         max_chunk_bytes,
         max_chunks,
-        writer_rank=0,
+        writer_rank: int = 0,
+        extra_writer_handler=None,
+        blocking: bool = True,
     ) -> "MessageQueue":
         if isinstance(pg, ProcessGroup):
             group_rank = dist.get_rank(pg)
@@ -609,23 +639,30 @@ def create_from_process_group(
             group_rank = pg.rank
             group_world_size = pg.world_size
             global_ranks = list(range(pg.world_size))

         from vllm.distributed.parallel_state import in_the_same_node_as

         status = in_the_same_node_as(pg, source_rank=writer_rank)
         same_node_ranks = [i for i, s in enumerate(status) if s]
         n_reader = group_world_size - 1
         n_local_reader = len(same_node_ranks) - 1
         local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+        if extra_writer_handler is not None:
+            n_reader = group_world_size
+            n_local_reader = len(same_node_ranks)
         buffer_io: MessageQueue
         if group_rank == writer_rank:
-            buffer_io = MessageQueue(
-                n_reader=n_reader,
-                n_local_reader=n_local_reader,
-                local_reader_ranks=local_reader_ranks,
-                max_chunk_bytes=max_chunk_bytes,
-                max_chunks=max_chunks,
-            )
+            if extra_writer_handler is not None:
+                buffer_io = MessageQueue.create_from_handle(
+                    extra_writer_handler, group_rank
+                )
+            else:
+                buffer_io = MessageQueue(
+                    n_reader=n_reader,
+                    n_local_reader=n_local_reader,
+                    local_reader_ranks=local_reader_ranks,
+                    max_chunk_bytes=max_chunk_bytes,
+                    max_chunks=max_chunks,
+                )
             handle = buffer_io.export_handle()
             if isinstance(pg, ProcessGroup):
                 dist.broadcast_object_list(
@@ -643,5 +680,6 @@ def create_from_process_group(
         else:
             handle = pg.broadcast_obj(None, writer_rank)
             buffer_io = MessageQueue.create_from_handle(handle, group_rank)
-        buffer_io.wait_until_ready()
+        if blocking:
+            buffer_io.wait_until_ready()
         return buffer_io
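For context, a hedged usage sketch of the new single-reader helper: every rank constructs its own queue as the writer, and the reader rank gathers the exported handles so it can later attach to each writer's queue via MessageQueue.create_from_handle. The function name build_worker_to_driver_queues and the chunk sizes are assumptions, not the call sites introduced by this PR.

import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

def build_worker_to_driver_queues(pg, reader_rank: int = 0):
    # Each rank creates a queue it will write to; the reader gathers all handles.
    my_queue, handles = MessageQueue.create_from_process_group_single_reader(
        pg,
        max_chunk_bytes=1 << 22,  # illustrative sizes
        max_chunks=16,
        reader_rank=reader_rank,
    )
    if dist.get_rank(pg) == reader_rank:
        # Attach to every other rank's queue using the gathered handles.
        reader_side = {
            rank: MessageQueue.create_from_handle(handle, dist.get_rank(pg))
            for rank, handle in enumerate(handles)
            if rank != reader_rank
        }
        return my_queue, reader_side
    return my_queue, {}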
@@ -370,6 +370,10 @@ class EngineArgs:
     ) = ParallelConfig.distributed_executor_backend
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
+    distributed_master_ip: str = ParallelConfig.distributed_master_ip
+    distributed_master_port: int = ParallelConfig.distributed_master_port
+    distributed_node_size: int = ParallelConfig.distributed_node_size
+    distributed_node_rank: int = ParallelConfig.distributed_node_rank
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
@@ -720,6 +724,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "-pp",
             **parallel_kwargs["pipeline_parallel_size"],
         )
+        parallel_group.add_argument(
+            "--distributed-master-ip", **parallel_kwargs["distributed_master_ip"]
+        )
+        parallel_group.add_argument(
+            "--distributed-master-port", **parallel_kwargs["distributed_master_port"]
+        )
+        parallel_group.add_argument(
+            "--distributed-node-size", **parallel_kwargs["distributed_node_size"]
+        )
+        parallel_group.add_argument(
+            "--distributed-node-rank", **parallel_kwargs["distributed_node_rank"]
+        )
         parallel_group.add_argument(
             "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
         )
@@ -1471,6 +1487,10 @@ def create_engine_config(
             data_parallel_rank=self.data_parallel_rank or 0,
             data_parallel_external_lb=data_parallel_external_lb,
             data_parallel_size_local=data_parallel_size_local,
+            distributed_master_ip=self.distributed_master_ip,
+            distributed_master_port=self.distributed_master_port,
+            distributed_node_size=self.distributed_node_size,
+            distributed_node_rank=self.distributed_node_rank,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
             data_parallel_backend=self.data_parallel_backend,
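Taken together with the ParallelConfig changes above, the new args are meant to be set per node. A hedged sketch of how they might be passed through EngineArgs follows; the model name, address, port, and GPU counts are placeholders, and the equivalent CLI flags are --distributed-master-ip, --distributed-master-port, --distributed-node-size, and --distributed-node-rank.

from vllm.engine.arg_utils import EngineArgs

# Hypothetical two-node setup: run with distributed_node_rank=0 on the node
# that owns distributed_master_ip, and distributed_node_rank=1 on the other node.
engine_args = EngineArgs(
    model="meta-llama/Llama-3.1-70B-Instruct",
    tensor_parallel_size=16,                 # spans 2 nodes x 8 GPUs
    distributed_executor_backend="mp",
    distributed_master_ip="10.0.0.1",
    distributed_master_port=29500,
    distributed_node_size=2,
    distributed_node_rank=0,
)
vllm_config = engine_args.create_engine_config()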
Review comment: why remove this one?

Reply: this file is deleted already.