diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 375645fde747..392ede8e8ae9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -1151,6 +1151,19 @@ steps: - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py - VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)' +- label: Elastic EP Scaling Test + timeout_in_minutes: 20 + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/compilation/ + - tests/distributed/ + commands: + - pytest -v -s distributed/test_elastic_ep.py + - label: Plugin Tests (2 GPUs) # 40min timeout_in_minutes: 60 mirror_hardwares: [amdexperimental] diff --git a/tests/distributed/test_elastic_ep.py b/tests/distributed/test_elastic_ep.py new file mode 100644 index 000000000000..24320dcfb53f --- /dev/null +++ b/tests/distributed/test_elastic_ep.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +import time + +import requests + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ..utils import RemoteOpenAIServer, _test_completion, multi_gpu_test + +MODEL_NAME = "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8" + + +def _send_scale_command(server: RemoteOpenAIServer, new_dp_size: int) -> bool: + url = server.url_for("scale_elastic_ep") + payload = {"new_data_parallel_size": new_dp_size} + headers = {"Content-Type": "application/json"} + + try: + response = requests.post(url, json=payload, headers=headers, timeout=300) + return response.status_code == 200 + except requests.exceptions.RequestException: + return False + + +@multi_gpu_test(num_gpus=4) +def test_elastic_ep_scaling(): + vllm_serve_args = [ + "--trust-remote-code", + "--disable-log-requests", + "--tensor-parallel-size", + "1", + "--gpu-memory-utilization", + "0.9", + "--max-model-len", + "16384", + "--no-enable-prefix-caching", + "--enable-expert-parallel", + "--all2all-backend", + "pplx", + "--enable-elastic-ep", + "--enable-eplb", + "--eplb-config.num_redundant_experts", + "128", + "--data-parallel-backend", + "ray", + "--data-parallel-size", + "2", + "--data-parallel-size-local", + "2", + "--data-parallel-start-rank", + "0", + ] + + leader_address = os.environ.get("LEADER_ADDRESS") + if leader_address: + vllm_serve_args.extend(["--data-parallel-address", leader_address]) + + tokenizer = get_tokenizer(MODEL_NAME, trust_remote_code=True) + prompt = "Hello, my name is" + token_ids = tokenizer(prompt).input_ids + + # timeout is 20 minutes + with RemoteOpenAIServer( + MODEL_NAME, vllm_serve_args, env_dict={}, max_wait_seconds=1200 + ) as server: + client = server.get_client() + _test_completion(client, MODEL_NAME, prompt, token_ids) + + # Scale up from 2->4 + assert _send_scale_command(server, 4) + time.sleep(10) + _test_completion(client, MODEL_NAME, prompt, token_ids) + + # Scale down from 4->2 + assert _send_scale_command(server, 2) + time.sleep(5) + _test_completion(client, MODEL_NAME, prompt, token_ids) diff --git a/tools/ep_kernels/elastic_ep/install_eep_libraries.sh b/tools/ep_kernels/elastic_ep/install_eep_libraries.sh index 9d7dc1032f5e..25584666f30d 100755 --- a/tools/ep_kernels/elastic_ep/install_eep_libraries.sh +++ b/tools/ep_kernels/elastic_ep/install_eep_libraries.sh @@ -52,6 +52,12 @@ if [ -z "$CUDA_HOME" ]; then exit 1 fi +# 
assume TORCH_CUDA_ARCH_LIST is set correctly +if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then + echo "TORCH_CUDA_ARCH_LIST is not set, please set it to your desired architecture." + exit 1 +fi + # disable all features except IBGDA export NVSHMEM_IBGDA_SUPPORT=1 @@ -82,5 +88,6 @@ git clone https://github.com/ppl-ai/pplx-kernels cd pplx-kernels # see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 # PIP_NO_BUILD_ISOLATION=0 disables build isolation -PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . --no-deps -v +git checkout 12cecfd +PIP_NO_BUILD_ISOLATION=0 pip install . --no-deps -v diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index b120c85bf232..77d1ed107a27 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -258,3 +258,34 @@ def _dispatch_to_compiled_code(self): yield finally: self.__class__.forward.__code__ = original + + +def reset_compile_wrapper(model: torch.nn.Module) -> None: + """ + Clean up compiled model and captured CUDA graphs for elastic EP. + """ + if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr( + model, "model" + ): + model = model.model + if not isinstance(model, TorchCompileWithNoGuardsWrapper): + return + # model.do_not_compile is set by the @support_torch_compile decorator + if hasattr(model, "do_not_compile") and model.do_not_compile: + return + from vllm.compilation.counter import compilation_counter + + # reset the compilation counter + compilation_counter.num_models_seen = 0 + compilation_counter.num_graphs_seen = 0 + compilation_counter.num_piecewise_graphs_seen = 0 + compilation_counter.num_piecewise_capturable_graphs_seen = 0 + compilation_counter.num_backend_compilations = 0 + compilation_counter.num_gpu_runner_capture_triggers = 0 + compilation_counter.num_cudagraph_captured = 0 + compilation_counter.num_inductor_compiles = 0 + compilation_counter.num_eager_compiles = 0 + compilation_counter.num_cache_entries_updated = 0 + compilation_counter.num_compiled_artifacts_saved = 0 + compilation_counter.stock_torch_compile_count = 0 + TorchCompileWithNoGuardsWrapper.__init__(model) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 4a8c8bc17cfc..075abd099d05 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -128,6 +128,7 @@ class ParallelConfig: "pplx", "deepep_high_throughput", "deepep_low_latency", + "nixl_ep", "allgather_reducescatter", "flashinfer_all2allv", ] @@ -150,6 +151,9 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + enable_elastic_ep: bool = False + """Enable elastic expert parallelism with stateless NCCL groups for DP/EP.""" + enable_dbo: bool = False """Enable dual batch overlap for the model executor.""" @@ -223,6 +227,29 @@ class is dynamically inherited by the worker class. This is used to inject Set to be private as it's not intended to be configured by users. """ + _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless DP groups when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + It is a list of list[int], with each inner list contains a set of 3 ports + to be used for setting up the stateless CPU/device/TCPStore groups + in StatelessGroupCoordinator. 
The number of inner lists is equal to + the number of DP groups, + i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size, + and len(self._stateless_dp_group_port_list[i]) == 3 for all i. + """ + + _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless EP groups when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size, + """ + + _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless world group when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + len(self._stateless_world_group_port_list) == 1, + """ + decode_context_parallel_size: int = 1 """Number of decode context parallel groups, because the world size does not change by dcp, it simply reuse the GPUs of TP group, and tp_size @@ -342,7 +369,16 @@ def get_next_dp_init_port(self) -> int: return answer - def stateless_init_dp_group(self) -> ProcessGroup: + def get_next_stateless_world_group_port(self) -> list[int]: + return self._stateless_world_group_port_list.pop() + + def get_next_stateless_dp_group_port(self) -> list[int]: + return self._stateless_dp_group_port_list.pop() + + def get_next_stateless_ep_group_port(self) -> list[int]: + return self._stateless_ep_group_port_list.pop() + + def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup: # NOTE: In high-concurrency scenarios multiple processes # can pick the same (currently free) port through a race # condition when calling `get_open_port()`. When the first @@ -366,7 +402,8 @@ def stateless_init_dp_group(self) -> ProcessGroup: self.get_next_dp_init_port(), self.data_parallel_rank, self.data_parallel_size, - backend=current_platform.dist_backend, + backend="gloo", + return_store=return_store, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. @@ -398,6 +435,7 @@ def use_sequence_parallel_moe(self) -> bool: "naive", "deepep_high_throughput", "deepep_low_latency", + "nixl_ep", ) and self.enable_expert_parallel and self.tensor_parallel_size > 1 @@ -511,6 +549,46 @@ def __post_init__(self) -> None: logger.info("Using external launcher for distributed inference.") self.world_size *= self.data_parallel_size + # Initialize stateless group ports for elastic EP + if self.enable_elastic_ep: + if not self.enable_eplb: + raise ValueError("Elastic EP is only supported with enable_eplb=True.") + num_world_groups = 1 + dp_size = self.data_parallel_size + ep_size = self.data_parallel_size * self.world_size_across_dp + num_dp_groups = max(1, self.world_size_across_dp // dp_size) + num_ep_groups = max(1, self.world_size_across_dp // ep_size) + + # NOTE(yongji): + # we need 3 ports for each comm group in `StatelessGroupCoordinator`. + # one for stateless CPU group, one for stateless device group, + # one for stateless TCPStore group. + total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3 + if not self._stateless_world_group_port_list: + all_ports = get_open_ports_list(total_ports_needed + 5) + # NOTE(yongji): allocate 5 ports for _data_parallel_master_port_list + # as in the case when elastic EP is not enabled + # (the regular DP code path below this if: `get_open_ports_list(5)`). 
+ # We must set _data_parallel_master_port_list here instead of + # letting the regular DP code path to set it, since + # we should call get_open_ports_list() only once + # to ensure the allocated ports are distinct. + self._data_parallel_master_port_list = all_ports[-5:] + all_ports = all_ports[:-5] + self._stateless_world_group_port_list = [ + all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) + ] + start_idx = num_world_groups * 3 + self._stateless_dp_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_dp_groups * 3, 3) + ] + start_idx += num_dp_groups * 3 + self._stateless_ep_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_ep_groups * 3, 3) + ] + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. if self.distributed_executor_backend == "external_launcher": diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index c40dde26b741..623a7529a651 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -32,8 +32,8 @@ class NaiveAll2AllManager(All2AllManagerBase): debugging. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def naive_multicast( self, @@ -105,8 +105,8 @@ class AgRsAll2AllManager(All2AllManagerBase): all-gather (dispatch) and reduce-scatter (combine). """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def dispatch( self, @@ -155,13 +155,16 @@ class PPLXAll2AllManager(All2AllManagerBase): All2All communication based on PPLX kernels. """ - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_pplx(), ( "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" " to install pplx_kernels." ) - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) + self.nvshmem_initialized = False + self.handle_cache = Cache() + def get_handle(self, kwargs): if self.internode: # inter-node communication needs nvshmem, # intra-node communication uses p2p mapping directly @@ -181,17 +184,18 @@ def __init__(self, cpu_group): if self.rank == 0 else nvshmem_alloc_empty_unique_id() ) - dist.broadcast( - uid, - src=dist.get_process_group_ranks(self.cpu_group)[0], - group=self.cpu_group, - ) + if self.tcp_store_group is not None: + uid = self.tcp_store_group.broadcast_obj(uid, src=0) + else: + dist.broadcast( + uid, + src=dist.get_process_group_ranks(self.cpu_group)[0], + group=self.cpu_group, + ) logger.debug("PPLX NVSHMEM UID = %s", uid) nvshmem_init(uid, self.rank, self.world_size) + self.nvshmem_initialized = True - self.handle_cache = Cache() - - def get_handle(self, kwargs): import pplx_kernels as pplx # type: ignore[import-not-found] return self.handle_cache.get_or_create( @@ -231,12 +235,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): All2All communication based on DeepEP High-Throughput kernels. """ - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_deep_ep(), ( "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" " to install DeepEP kernels." 
) # noqa - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) self.handle_cache = Cache() # This is the DeepEP default. Stick to it till we can establish @@ -268,8 +272,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): All2All communication based on DeepEP High-Throughput kernels. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. @@ -325,8 +329,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): All2All communication based on DeepEP Low-Latency kernels. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def _make_all2all_kwargs( self, @@ -386,6 +390,134 @@ def max_sms_used(self) -> int | None: return 0 +class NixlEPAll2AllManager(All2AllManagerBase): + """ + All2All communication based on NIXL EP kernels. + This backend supports elastic EP with dynamic rank connection/disconnection. + """ + + # (nixl_ep_buffer, ep_size) + _buffer: tuple[Any, int] | None = None + + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) + import os + + self.max_num_ep_ranks = envs.VLLM_NIXL_EP_MAX_NUM_RANKS + assert envs.VLLM_NIXL_EP_UCX_IB_DEVICES is not None, ( + "VLLM_NIXL_EP_UCX_IB_DEVICES is not set" + ) + assert envs.VLLM_NIXL_EP_UCX_TCP_DEVICES is not None, ( + "VLLM_NIXL_EP_UCX_TCP_DEVICES is not set" + ) + if envs.VLLM_NIXL_EP_ETCD_ENDPOINTS is not None: + os.environ["NIXL_ETCD_ENDPOINTS"] = envs.VLLM_NIXL_EP_ETCD_ENDPOINTS + if envs.VLLM_NIXL_EP_PLUGIN_DIR is not None: + os.environ["NIXL_PLUGIN_DIR"] = envs.VLLM_NIXL_EP_PLUGIN_DIR + + from vllm.distributed.parallel_state import get_pp_group, get_tp_group + + # NOTE(yongji): envs.LOCAL_RANK may not be set + # an ugly way to get current worker's device index under DPEngineCoreActor + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES.split(",") + assert get_pp_group().world_size == 1 + local_device_index = int(cuda_visible_devices[get_tp_group().rank_in_group]) + ucx_ib_nics = envs.VLLM_NIXL_EP_UCX_IB_DEVICES.split(",") + pxb_ib_nic = ucx_ib_nics[local_device_index] + os.environ["UCX_NET_DEVICES"] = ( + f"cuda0-{pxb_ib_nic}:1" + "," + envs.VLLM_NIXL_EP_UCX_TCP_DEVICES + ) + + def _init_buffer( + self, + max_num_tokens_per_dp_rank: int, + token_hidden_size: int, + num_experts_per_rank: int, + ) -> None: + from nixl_ep import Buffer # type: ignore[import-not-found] + + max_num_global_experts = self.max_num_ep_ranks * num_experts_per_rank + num_rdma_bytes = Buffer.get_rdma_size_hint( + num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, + hidden=token_hidden_size, + num_ranks=self.max_num_ep_ranks, + num_experts=max_num_global_experts, + ) + assert NixlEPAll2AllManager._buffer is None, ( + "NIXL EP buffer already initialized" + ) + buffer = Buffer( + nvlink_backend="nixl", + explicitly_destroy=True, + rank=self.rank, + enable_shrink=True, + ) + buffer.update_memory_buffers( + num_ranks=self.max_num_ep_ranks, + num_experts_per_rank=num_experts_per_rank, + num_rdma_bytes=num_rdma_bytes, + ) + ranks_to_connect = list(range(self.cpu_group.size())) + buffer.connect_ranks(ranks_to_connect) + NixlEPAll2AllManager._buffer = (buffer, self.cpu_group.size()) + + def _update_buffer(self): + assert 
NixlEPAll2AllManager._buffer is not None + buffer, current_ep_size = NixlEPAll2AllManager._buffer + current_ranks = list(range(current_ep_size)) + new_ep_size = self.cpu_group.size() + if new_ep_size > len(current_ranks): + ranks_to_connect = list(range(len(current_ranks), new_ep_size)) + buffer.connect_ranks(ranks_to_connect) + else: + ranks_to_disconnect = current_ranks[new_ep_size:] + buffer.disconnect_ranks(ranks_to_disconnect) + + def get_handle(self, kwargs): + if ( + NixlEPAll2AllManager._buffer is not None + and NixlEPAll2AllManager._buffer[1] == self.cpu_group.size() + ): + return NixlEPAll2AllManager._buffer[0] + + num_experts_per_rank = kwargs["num_global_experts"] // kwargs["num_ep_ranks"] + nixl_kwargs = dict( + max_num_tokens_per_dp_rank=kwargs["max_num_tokens_per_dp_rank"], + token_hidden_size=kwargs["token_hidden_size"], + num_experts_per_rank=num_experts_per_rank, + ) + if NixlEPAll2AllManager._buffer is None: + self._init_buffer(**nixl_kwargs) + else: + self._update_buffer() + + assert NixlEPAll2AllManager._buffer is not None + handle = NixlEPAll2AllManager._buffer[0] + return handle + + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + # NOTE(yongji): NIXLEPAll2AllManager instance is recreated during + # scale-up/down, so we cannot destroy the persistent buffer here. + pass + + # NIXL EP uses RDMA so no SMs are used for communication + def max_sms_used(self) -> int | None: + return 0 + + class FlashInferAllToAllManager(All2AllManagerBase): """ All2All communication based on flashinfer kernels. @@ -396,11 +528,11 @@ class FlashInferAllToAllManager(All2AllManagerBase): rank: int world_size: int - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_flashinfer_all2all(), ( "flashinfer all2all module not found. 
Please install/check flashinfer" ) # noqa - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) logger.debug( "Initialize for flashinfer All2All rank=%d, world size=%d", self.rank, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 3a849da70e4c..63de1fa88a27 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -29,8 +29,9 @@ class All2AllManagerBase: rank: int world_size: int - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): self.cpu_group = cpu_group + self.tcp_store_group = tcp_store_group # compute some common properties from vllm.distributed.parallel_state import ( @@ -47,12 +48,17 @@ def __init__(self, cpu_group): # when we create this object self.dp_rank = self.dp_group.rank_in_group self.dp_world_size = self.dp_group.world_size - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) + self.rank = cpu_group.rank() + self.world_size = cpu_group.size() # all2all communication often has separate implementations for # intra-node and inter-node communication - self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + if tcp_store_group is None: + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + else: + self.internode = not all( + in_the_same_node_as(tcp_store_group, source_rank=0) + ) def get_handle(self, kwargs): # get a handle for the all2all communication, @@ -98,17 +104,36 @@ def __init__( device: torch.device | None = None, device_group: ProcessGroup | None = None, unique_name: str = "", + global_ranks: list[int] | None = None, + global_world_size: int | None = None, ): self.device = device or torch.device("cpu") self.cpu_group = cpu_group self.device_group = device_group self.unique_name = unique_name - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) - self.ranks = dist.get_process_group_ranks(cpu_group) - self.global_rank = dist.get_rank() - self.global_world_size = dist.get_world_size() - self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) + + # Check if this is a stateless process group + from torch.distributed.distributed_c10d import _world + + is_stateless = _world.pg_map.get(cpu_group, None) is None + + if is_stateless: + # For stateless groups, we can't use torch.distributed methods + self.rank = cpu_group.rank() + self.world_size = cpu_group.size() + assert global_ranks is not None + assert global_world_size is not None + self.ranks = global_ranks + self.global_rank = self.ranks[self.rank] + self.global_world_size = global_world_size + self.rank_in_group = self.rank + else: + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) use_ep = False all2all_backend = None @@ -252,6 +277,13 @@ def recv( torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast a tensor from source rank to all ranks.""" + if self.world_size == 1: + return tensor + torch.distributed.broadcast(tensor, self.ranks[src], self.device_group) + return 
tensor + def destroy(self): pass @@ -295,3 +327,6 @@ def combine( This is a no-op in the base class. """ return hidden_states + + def batch_isend_irecv(self, p2p_ops: list): + raise NotImplementedError diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 2e878eef908a..26a31d534028 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -16,6 +16,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform +from ..utils import StatelessProcessGroup from .base_device_communicator import DeviceCommunicatorBase logger = init_logger(__name__) @@ -28,8 +29,18 @@ def __init__( device: torch.device | None = None, device_group: ProcessGroup | None = None, unique_name: str = "", + global_ranks: list[int] | None = None, + global_world_size: int | None = None, + tcp_store_group: StatelessProcessGroup | None = None, ): - super().__init__(cpu_group, device, device_group, unique_name) + super().__init__( + cpu_group, + device, + device_group, + unique_name, + global_ranks, + global_world_size, + ) if "tp" not in unique_name: # custom allreduce or torch symm mem can be used only by tp use_custom_allreduce = False @@ -56,7 +67,7 @@ def __init__( self.pynccl_comm: PyNcclCommunicator | None = None if self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, + group=self.cpu_group if tcp_store_group is None else tcp_store_group, device=self.device, ) if is_symmetric_memory_enabled(): @@ -93,27 +104,45 @@ def __init__( if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager - self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + self.all2all_manager = NaiveAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "allgather_reducescatter": from .all2all import AgRsAll2AllManager - self.all2all_manager = AgRsAll2AllManager(self.cpu_group) + self.all2all_manager = AgRsAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "pplx": from .all2all import PPLXAll2AllManager - self.all2all_manager = PPLXAll2AllManager(self.cpu_group) + self.all2all_manager = PPLXAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "deepep_high_throughput": from .all2all import DeepEPHTAll2AllManager - self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) + self.all2all_manager = DeepEPHTAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "deepep_low_latency": from .all2all import DeepEPLLAll2AllManager - self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) + self.all2all_manager = DeepEPLLAll2AllManager( + self.cpu_group, tcp_store_group + ) + elif self.all2all_backend == "nixl_ep": + from .all2all import NixlEPAll2AllManager + + self.all2all_manager = NixlEPAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "flashinfer_all2allv": from .all2all import FlashInferAllToAllManager - self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) + self.all2all_manager = FlashInferAllToAllManager( + self.cpu_group, tcp_store_group + ) else: raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") @@ -261,6 +290,18 @@ def recv( torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast a tensor from source rank to 
all ranks.""" + if self.world_size == 1: + return tensor + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.broadcast(tensor, src) + return tensor + else: + raise ValueError("No PyNCCL communicator found") + def destroy(self): if self.pynccl_comm is not None: self.pynccl_comm = None @@ -338,3 +379,10 @@ def combine( hidden_states, is_sequence_parallel ) return hidden_states + + def batch_isend_irecv(self, p2p_ops: list): + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.batch_isend_irecv(p2p_ops) + else: + raise ValueError("No PyNCCL communicator found") diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 2fc35e80f591..44dc113e4f55 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -312,10 +312,19 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None): ) if stream is None: stream = current_stream() + if tensor.dtype in [ + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ]: + nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8) + else: + nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype) self.nccl.ncclSend( buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + nccl_dtype, dst, self.comm, cudaStream_t(stream.cuda_stream), @@ -330,10 +339,19 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): ) if stream is None: stream = current_stream() + if tensor.dtype in [ + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ]: + nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8) + else: + nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype) self.nccl.ncclRecv( buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + nccl_dtype, src, self.comm, cudaStream_t(stream.cuda_stream), @@ -384,3 +402,17 @@ def register_comm_window_raw(self, ptr: int, size: int): def deregister_comm_window(self, window): return self.nccl.ncclCommWindowDeregister(self.comm, window) + + def batch_isend_irecv(self, p2p_ops: list, stream=None): + if self.disabled: + return + if stream is None: + stream = current_stream() + self.group_start() + for op in p2p_ops: + if op.op is torch.distributed.isend: + self.send(op.tensor, op.group_peer, stream) + elif op.op is torch.distributed.irecv: + self.recv(op.tensor, op.group_peer, stream) + + self.group_end() diff --git a/vllm/distributed/elastic_ep/__init__.py b/vllm/distributed/elastic_ep/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py new file mode 100644 index 000000000000..550f35b3b2dd --- /dev/null +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -0,0 +1,490 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +import weakref +from collections.abc import Iterable, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import P2POp + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.wrapper import reset_compile_wrapper +from vllm.config import ( + CompilationMode, + set_current_vllm_config, +) +from vllm.distributed 
import ( + get_dp_group, + get_ep_group, + get_pcp_group, + get_standby_dp_group, + get_standby_ep_group, + get_tp_group, +) +from vllm.distributed.parallel_state import ( + create_standby_groups, + prepare_communication_buffer_for_model, + switch_to_standby_groups, +) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig +from vllm.utils.torch_utils import supports_dynamo +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper + +logger = init_logger(__name__) + + +def batch_transfer_weights( + model: nn.Module, + is_sender: bool, + peer_rank: int, + dp_group: StatelessGroupCoordinator, + expert_weights: Sequence[Iterable[torch.Tensor]], +) -> None: + device_comm = dp_group.device_communicator + if device_comm is None: + raise ValueError("No device communicator found") + + expert_weights_set = set() + for weight_group in expert_weights: + for weight in weight_group: + expert_weights_set.add(weight.data_ptr()) + + state_dict = model.state_dict() + all_params = [] + + for name, param in state_dict.items(): + if name.endswith("expert_map"): + continue + if param.data_ptr() not in expert_weights_set: + all_params.append(param.data) + + assert len(all_params) > 0 + p2p_ops = [] + for param in all_params: + op = object.__new__(P2POp) + if is_sender: + op.op = torch.distributed.isend + op.tensor = param + else: + op.op = torch.distributed.irecv + op.tensor = param + op.group_peer = peer_rank + p2p_ops.append(op) + device_comm.batch_isend_irecv(p2p_ops) + + +def broadcast_expert_mapping( + physical_to_logical: torch.Tensor | None, + num_local_physical_experts: int | None, + num_logical_experts: int | None, + dp_group: StatelessGroupCoordinator, + device: torch.device, + src_rank: int = 0, +) -> tuple[torch.Tensor, int, int]: + if dp_group.rank_in_group == src_rank: + assert physical_to_logical is not None + assert num_local_physical_experts is not None + assert num_logical_experts is not None + assert physical_to_logical.dtype == torch.int64 + shape_tensor = torch.tensor( + list(physical_to_logical.shape), dtype=torch.int64, device="cpu" + ) + metadata_tensor = torch.tensor( + [num_local_physical_experts, num_logical_experts], + dtype=torch.int64, + device="cpu", + ) + else: + shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu") + metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu") + + shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank) + metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank) + + if dp_group.rank_in_group != src_rank: + assert device is not None + physical_to_logical = torch.empty( + tuple(shape_tensor.tolist()), + dtype=torch.int64, + device=device, + ) + + assert physical_to_logical is not None + physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank) + num_local_physical_experts = int(metadata_tensor[0].item()) + num_logical_experts = int(metadata_tensor[1].item()) + + return physical_to_logical, num_local_physical_experts, num_logical_experts + + +class ElasticEPScalingExecutor: + def __init__(self, worker): + self.worker_ref = weakref.ref(worker) + self.reconfig_request = None + + @property + def worker(self): + worker = self.worker_ref() + if worker is None: + raise RuntimeError("Worker has been garbage collected") + return worker + + def execute(self, execute_method: 
str, *args, **kwargs): + method = getattr(self, execute_method, None) + if method is None: + raise ValueError(f"Unknown execute method: {execute_method}") + return method(*args, **kwargs) + + def create_standby_groups( + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: + self.reconfig_request = reconfig_request + new_dp_size = reconfig_request.new_data_parallel_size + world_size = self.worker.vllm_config.parallel_config.world_size + new_world_size_across_dp = world_size * new_dp_size + # TODO(yongji): check whether we need to use updated vllm_config here + with set_current_vllm_config(self.worker.vllm_config): + create_standby_groups( + new_dp_size=new_dp_size, + new_world_size_across_dp=new_world_size_across_dp, + master_ip=reconfig_request.new_data_parallel_master_ip, + world_group_ports=reconfig_request.new_stateless_world_group_port_list, + dp_group_ports=reconfig_request.new_stateless_dp_group_port_list, + ep_group_ports=reconfig_request.new_stateless_ep_group_port_list, + ) + self.worker.model_runner.eplb_disabled = True + standby_ep_group = get_standby_ep_group() + assert standby_ep_group is not None + if standby_ep_group.rank == 0: + logger.info("[Elastic EP] EPLB disabled during elastic scaling transition") + + def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None: + standby_dp_group = get_standby_dp_group() + assert standby_dp_group is not None + # Broadcast old_dp_size to all workers in standby group + if standby_dp_group.rank_in_group < old_dp_size: + old_dp_size_tensor = torch.tensor( + [old_dp_size], dtype=torch.int64, device="cpu" + ) + else: + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast( + old_dp_size_tensor, 0 + ) + + num_new_workers = new_dp_size - old_dp_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + ranks_to_send = [] + # NOTE(yongji): determine sender-receiver pairing in weight transfer. + # Mapping rule: + # Base: each existing worker i gets (num_new_workers // old_dp_size) new workers + # to send weights to. Worker i sends weights to new workers with global ranks + # in [old_dp_size + i * num_dst_per_sender, + # old_dp_size + (i + 1) * num_dst_per_sender]. + # Remainder: Each of the first (num_new_workers % old_dp_size) existing workers + # gets an additional new worker to send weights to, whose global rank is + # old_dp_size * (num_dst_per_sender + 1) + i. 
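+        # Worked example (hypothetical sizes): old_dp_size=2, new_dp_size=5
+        # gives num_new_workers=3 and num_dst_per_sender=1, so existing rank 0
+        # sends to new rank 2 plus the remainder rank 4, while existing rank 1
+        # sends only to new rank 3.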
+ num_dst_per_sender = num_new_workers // old_dp_size + sender_pos = dp_rank + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end)) + remainder_start = old_dp_size * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < num_new_workers: + ranks_to_send.append(old_dp_size + recver_pos) + + model = self.worker.model_runner.get_model() + for new_worker_rank in sorted(ranks_to_send): + batch_transfer_weights( + model=model, + is_sender=True, + peer_rank=new_worker_rank, + dp_group=standby_dp_group, + expert_weights=model.expert_weights, + ) + torch.cuda.synchronize() + + def broadcast_expert_mapping(self) -> None: + standby_dp_group = get_standby_dp_group() + assert standby_dp_group is not None + model_config = self.worker.model_runner.model_config + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + physical_to_logical = eplb_model_state.physical_to_logical_map + num_physical_experts = physical_to_logical.shape[1] + num_local_physical_experts = num_physical_experts // get_ep_group().world_size + num_logical_experts = eplb_model_state.logical_replica_count.shape[1] + broadcast_expert_mapping( + physical_to_logical=physical_to_logical, + num_local_physical_experts=num_local_physical_experts, + num_logical_experts=num_logical_experts, + dp_group=standby_dp_group, + src_rank=0, + device=self.worker.device, + ) + + def switch_and_prepare(self) -> None: + old_dp_size = get_dp_group().world_size + old_ep_size = get_ep_group().world_size + + switch_to_standby_groups() + + parallel_config = self.worker.vllm_config.parallel_config + reconfig_request = self.reconfig_request + assert reconfig_request is not None + new_dp_size = reconfig_request.new_data_parallel_size + new_ep_size = get_ep_group().world_size + + parallel_config.data_parallel_size = new_dp_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( + reconfig_request.new_data_parallel_rank_local + ) + parallel_config.data_parallel_master_ip = ( + reconfig_request.new_data_parallel_master_ip + ) + parallel_config.data_parallel_master_port = ( + reconfig_request.new_data_parallel_master_port + ) + + # Reconfigure MoE modules with new EP size + moe_modules = [ + module + for module in self.worker.model_runner.model.modules() + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) + ] + num_local_experts = moe_modules[0].moe_config.num_local_experts + assert all( + module.moe_config.num_local_experts == num_local_experts + for module in moe_modules + ), "All MoE modules must have the same number of experts" + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=get_tp_group().world_size, + pcp_size_=get_pcp_group().world_size, + dp_size_=get_dp_group().world_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + + # Update EPLB state + eplb_state = 
self.worker.model_runner.eplb_state + assert eplb_state is not None + model_config = self.worker.model_runner.model_config + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + + num_physical_experts = num_local_experts * new_ep_size + num_logical_experts = eplb_model_state.logical_replica_count.shape[1] + parallel_config.eplb_config.num_redundant_experts = ( + num_physical_experts - num_logical_experts + ) + old_physical_to_logical = eplb_model_state.physical_to_logical_map + num_moe_layers = old_physical_to_logical.shape[0] + num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size + if new_dp_size > old_dp_size: + expanded_physical_to_logical = torch.full( + (num_moe_layers, num_local_experts * new_ep_size), + -1, + dtype=old_physical_to_logical.dtype, + device=old_physical_to_logical.device, + ) + expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = ( + old_physical_to_logical + ) + eplb_model_state.physical_to_logical_map = expanded_physical_to_logical + + old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1] + pad_size = num_physical_experts - old_num_physical_experts + if new_dp_size > old_dp_size: + assert pad_size > 0 + expanded_expert_load_pass = F.pad( + eplb_model_state.expert_load_pass, (0, pad_size), value=0 + ) + expanded_expert_load_window = F.pad( + eplb_model_state.expert_load_window, (0, pad_size), value=0 + ) + eplb_model_state.expert_load_pass = expanded_expert_load_pass + eplb_model_state.expert_load_window = expanded_expert_load_window + eplb_state.num_valid_physical_experts = old_num_physical_experts + else: + assert pad_size < 0 + eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[ + :, :num_physical_experts + ] + eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[ + :, :, :num_physical_experts + ] + eplb_state.num_valid_physical_experts = num_physical_experts + + model = self.worker.model_runner.get_model() + model.expert_weights = [] + with set_current_vllm_config(self.worker.vllm_config): + model.set_eplb_state( + eplb_model_state.expert_load_pass, + eplb_model_state.logical_to_physical_map, + eplb_model_state.logical_replica_count, + ) + model.update_physical_experts_metadata( + num_physical_experts=num_physical_experts, + num_local_physical_experts=num_local_experts, + ) + prepare_communication_buffer_for_model(self.worker.model_runner.model) + if ( + self.worker.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + and supports_dynamo() + ): + # NOTE(yongji): when using stock torch.compile, + # torch.compile is triggered during GPUModelRunner's load_model() + # TODO(yongji):check do we need to re-trigger torch.compile here? + # any changes to the tensor shapes in execution should already + # be handled internally by torch.compile. + backend = self.worker.vllm_config.compilation_config.init_backend( + self.worker.vllm_config + ) + compilation_counter.stock_torch_compile_count += 1 + self.worker.model_runner.model.compile(fullgraph=True, backend=backend) + + # release all previously captured CUDA graphs + if isinstance(self.worker.model_runner.model, CUDAGraphWrapper): + # TODO(yongji): do we need to reset graph pool here? 
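+            # Clearing the cached entries drops the references to previously
+            # captured CUDA graphs; the memory is then reclaimed by the
+            # gc.collect() / torch.cuda.empty_cache() calls below before new
+            # graphs are captured in compile_or_warm_up_model().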
+ wrapper = self.worker.model_runner.model + wrapper.concrete_cudagraph_entries = {} + elif isinstance(self.worker.model_runner.model, UBatchWrapper): + raise RuntimeError("DBO is not yet supported in elastic EP") + + # reset the compile wrapper + with set_current_vllm_config(self.worker.vllm_config): + reset_compile_wrapper(self.worker.model_runner.get_model()) + + gc.collect() + torch.cuda.empty_cache() + self.worker.compile_or_warm_up_model() + + def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None: + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding...") + + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + + model_config = self.worker.model_runner.model_config + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + is_async_enabled = eplb_model_state.is_async_enabled + eplb_model_state.is_async_enabled = False + if new_dp_size is None: + eplb_state.rearrange() + else: + # scale down + parallel_config = self.worker.vllm_config.parallel_config + tp_size = parallel_config.tensor_parallel_size + old_ep_size = parallel_config.data_parallel_size * tp_size + new_ep_size = new_dp_size * tp_size + + rank_mapping = { + old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 + for old_ep_rank in range(old_ep_size) + } + + eplb_state.rearrange(rank_mapping=rank_mapping) + # NOTE(yongji): check whether we need to synchronize here + torch.cuda.synchronize() + # reset expert_rearrangement_step to ensure all ranks are synchronized + eplb_state.expert_rearrangement_step = 0 + eplb_model_state.is_async_enabled = is_async_enabled + self.worker.model_runner.eplb_disabled = False + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed") + + def receive_weights(self) -> None: + dp_group = get_dp_group() + assert isinstance(dp_group, StatelessGroupCoordinator) + new_dp_size = dp_group.world_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + # Receive old_dp_size broadcasted during transfer_weights + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) + old_dp_size = int(old_dp_size_tensor[0].item()) + + # Calculate which existing worker will send to this new worker + num_new_workers = new_dp_size - old_dp_size + new_worker_idx = dp_rank - old_dp_size + num_dst_per_sender = num_new_workers // old_dp_size + remainder = num_new_workers % old_dp_size + + if new_worker_idx < remainder * (num_dst_per_sender + 1): + sender_rank = new_worker_idx // (num_dst_per_sender + 1) + else: + sender_rank = ( + remainder + + (new_worker_idx - remainder * (num_dst_per_sender + 1)) + // num_dst_per_sender + ) + + model = self.worker.model_runner.get_model() + batch_transfer_weights( + model=model, + is_sender=False, + peer_rank=sender_rank, + dp_group=dp_group, + expert_weights=model.expert_weights, + ) + torch.cuda.synchronize() + + def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]: + dp_group = get_dp_group() + assert isinstance(dp_group, StatelessGroupCoordinator) + physical_to_logical, num_local_physical_experts, num_logical_experts = ( + broadcast_expert_mapping( + physical_to_logical=None, + num_local_physical_experts=None, + num_logical_experts=None, + dp_group=dp_group, + src_rank=0, + device=self.worker.device, + ) + ) + num_moe_layers = physical_to_logical.shape[0] + new_dp_size = get_dp_group().world_size + tp_size = 
self.worker.vllm_config.parallel_config.tensor_parallel_size + new_ep_size = new_dp_size * tp_size + expanded_physical_to_logical = torch.full( + (num_moe_layers, num_local_physical_experts * new_ep_size), + -1, + dtype=physical_to_logical.dtype, + device=physical_to_logical.device, + ) + old_num_physical_experts = physical_to_logical.shape[1] + expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical + return ( + expanded_physical_to_logical, + num_logical_experts, + old_num_physical_experts, + ) + + def prepare_new_worker(self) -> None: + with set_current_vllm_config(self.worker.vllm_config): + prepare_communication_buffer_for_model(self.worker.model_runner.get_model()) diff --git a/vllm/distributed/elastic_ep/elastic_state.py b/vllm/distributed/elastic_ep/elastic_state.py new file mode 100644 index 000000000000..1874952a38c5 --- /dev/null +++ b/vllm/distributed/elastic_ep/elastic_state.py @@ -0,0 +1,521 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum +import time +import weakref +from datetime import timedelta +from typing import TYPE_CHECKING, Literal + +import torch.distributed + +from vllm.config import ParallelConfig +from vllm.distributed import ( + sched_yield, + stateless_destroy_torch_distributed_process_group, +) +from vllm.logger import init_logger +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.engine.core import DPEngineCoreProc + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) + +WorkerType = Literal["existing", "new", "removing"] + + +class ScaleUpExistingEningeState(enum.IntEnum): + WAIT_NEW_WORKERS_INIT = 0 + CREATE_STANDBY_GROUPS = 1 + TRANSFER_EXPERT_MAPPING = 2 + WAIT_NEW_WORKERS_WEIGHTS_INIT = 3 + TRANSFER_WEIGHTS = 4 + SYNC_KV_CACHE_MEMORY_SIZE = 5 + SWITCH_AND_PREPARE = 6 + EPLB_RESHUFFLE = 7 + COMPLETE = 8 + + +class ScaleUpNewEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + COMPLETE = 2 + + +class ScaleDownRemainingEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + SWITCH_AND_PREPARE = 2 + COMPLETE = 3 + + +class ScaleDownRemovingEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + COMPLETE = 2 + + +class _BarrierTimeoutError(RuntimeError): + """ + Exception raised for timeout + in the first stage of our two-staged + TCPStore based barrier to synchronize the + execution of all engines in the DP group. 
+ """ + + +class ElasticEPScalingState: + def __init__( + self, + model_executor: "Executor", + engine_core: "DPEngineCoreProc", + vllm_config: "VllmConfig", + new_parallel_config: ParallelConfig, + worker_type: WorkerType, + scale_type: Literal["scale_up", "scale_down"], + reconfig_request: ReconfigureDistributedRequest | None = None, + ): + self.model_executor_ref = weakref.ref(model_executor) + self.engine_core_ref = weakref.ref(engine_core) + self.vllm_config = vllm_config + self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None + self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None + self.new_dp_group_or_config: torch.distributed.ProcessGroup | ParallelConfig = ( + self.engine_core.dp_group if worker_type == "new" else new_parallel_config + ) + self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None + self.worker_type = worker_type + self.scale_type = scale_type + self.reconfig_request = reconfig_request + self.last_barrier_timeout = False + + if scale_type == "scale_up": + self.state = ( + ScaleUpNewEngineState.PREPARE + if worker_type == "new" + else ScaleUpExistingEningeState.WAIT_NEW_WORKERS_INIT + ) + else: + self.state = ( + ScaleDownRemovingEngineState.PREPARE + if worker_type == "removing" + else ScaleDownRemainingEngineState.PREPARE + ) + + @property + def model_executor(self) -> "Executor": + model_executor = self.model_executor_ref() + if model_executor is None: + raise RuntimeError("Model executor has been garbage collected") + return model_executor + + @property + def engine_core(self) -> "DPEngineCoreProc": + engine_core = self.engine_core_ref() + if engine_core is None: + raise RuntimeError("Engine core has been garbage collected") + return engine_core + + def progress(self) -> bool: + if self.scale_type == "scale_up": + return ( + self._progress_new_engine() + if self.worker_type == "new" + else self._progress_existing_engine() + ) + return ( + self._progress_removing_engine() + if self.worker_type == "removing" + else self._progress_remaining_engine() + ) + + def _execute_tcp_store_barrier( + self, dp_store, group_rank, group_size, barrier_id, timeout=None + ): + arrival_key = f"arrival_{barrier_id}_{group_rank}" + dp_store.set(arrival_key, b"1") + + start_time = time.time() + processes_arrived: set[int] = set() + + while len(processes_arrived) < group_size: + if ( + timeout is not None + and time.time() - start_time > timeout.total_seconds() + ): + raise _BarrierTimeoutError( + f"Barrier timed out after {timeout.total_seconds()} seconds" + ) + + for i in range(group_size): + if i in processes_arrived: + continue + + key = f"arrival_{barrier_id}_{i}" + present = dp_store.check([key]) + if present: + processes_arrived.add(i) + + if len(processes_arrived) < group_size: + sched_yield() + + def _staged_barrier(self, use_new_group: bool) -> bool: + # NOTE(yongji): currently we use a two-staged + dp_store = self.new_dp_store if use_new_group else self.old_dp_store + dp_group = self.new_dp_group_or_config if use_new_group else self.old_dp_group + + group_rank = dp_group.rank() + group_size = dp_group.size() + barrier_id = "eep_barrier" + + # TODO(yongji): figure out appropriate timeout for the barrier + timeout = ( + None + if dp_store.check(["eep_barrier_sync"]) or self.last_barrier_timeout + else timedelta(seconds=5) + ) + + try: + self._execute_tcp_store_barrier( + dp_store, group_rank, group_size, barrier_id, timeout=timeout + ) + torch.distributed.barrier(dp_group) + self.last_barrier_timeout = 
False + # clean up barrier keys + if group_rank == 0: + dp_store.delete_key("eep_barrier_sync") + for i in range(group_size): + dp_store.delete_key(f"arrival_{barrier_id}_{i}") + return True + except _BarrierTimeoutError as e: + if timeout is None: + raise RuntimeError("Unexpected timeout encountered") from e + self.last_barrier_timeout = True + dp_store.compare_set("eep_barrier_sync", "", b"1") + return False + + def _progress_existing_engine(self) -> bool: + state = self.state + + if state == ScaleUpExistingEningeState.WAIT_NEW_WORKERS_INIT: + return False + + elif state == ScaleUpExistingEningeState.CREATE_STANDBY_GROUPS: + # NOTE(yongji): wait for all existing workers to receive the request + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier(use_new_group=False): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._create_standby_groups() + self.state = ScaleUpExistingEningeState.TRANSFER_EXPERT_MAPPING + return True + + elif state == ScaleUpExistingEningeState.TRANSFER_EXPERT_MAPPING: + self._transfer_expert_mapping() + self.state = ScaleUpExistingEningeState.WAIT_NEW_WORKERS_WEIGHTS_INIT + return True + + elif state == ScaleUpExistingEningeState.WAIT_NEW_WORKERS_WEIGHTS_INIT: + return False + + elif state == ScaleUpExistingEningeState.TRANSFER_WEIGHTS: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier(use_new_group=False): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._transfer_weights() + self.state = ScaleUpExistingEningeState.SYNC_KV_CACHE_MEMORY_SIZE + return True + + elif state == ScaleUpExistingEningeState.SYNC_KV_CACHE_MEMORY_SIZE: + self._sync_kv_cache_memory_size() + self.state = ScaleUpExistingEningeState.SWITCH_AND_PREPARE + return True + + elif state == ScaleUpExistingEningeState.SWITCH_AND_PREPARE: + self._switch_and_prepare() + self.state = ScaleUpExistingEningeState.EPLB_RESHUFFLE + self.new_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleUpExistingEningeState.EPLB_RESHUFFLE: + if ( + int(self.new_dp_store.get("eep_barrier_engine_count")) + < self.new_dp_group_or_config.size() + ): + return False + if not self._staged_barrier(use_new_group=True): + return False + if self.new_dp_group_or_config.rank() == 0: + self.new_dp_store.delete_key("eep_barrier_engine_count") + self._eplb_reshuffle() + self.state = ScaleUpExistingEningeState.COMPLETE + self._update_parallel_config() + return True + + else: + assert self.state == ScaleUpExistingEningeState.COMPLETE + return True + + def _progress_new_engine(self) -> bool: + state = self.state + + if state == ScaleUpNewEngineState.PREPARE: + tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu") + torch.distributed.all_reduce( + tensor, + op=torch.distributed.ReduceOp.MAX, + group=self.new_dp_group_or_config, + ) + data = tensor.tolist() + self.engine_core.engines_running = bool(data[0]) + self.engine_core.current_wave = int(data[1]) + self.engine_core.step_counter = int(data[2]) + self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE + self.new_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE: + if ( + int(self.new_dp_store.get("eep_barrier_engine_count")) + < self.new_dp_group_or_config.size() + ): + return False + 
if not self._staged_barrier(use_new_group=True): + return False + assert self.new_dp_group_or_config.rank() > 0 + self._eplb_reshuffle() + self.state = ScaleUpNewEngineState.COMPLETE + return True + + else: + assert self.state == ScaleUpNewEngineState.COMPLETE + return True + + def _progress_remaining_engine(self) -> bool: + state = self.state + + if state == ScaleDownRemainingEngineState.PREPARE: + self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE + self.old_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier(use_new_group=False): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._eplb_reshuffle_before_scale_down() + self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE + # NOTE(yongji): currently, after EPLB reshuffle + # that redistributes experts to remaining workers, workers + # to be removed will immediately initiate shutdown; + # existing workers can no longer execute forward steps using + # the old setup. In the future, we may keep + # the removing workers alive a bit longer, + # e.g., to drain in-batch requests. + self._create_standby_groups() + self._switch_and_prepare() + self._update_parallel_config() + self.state = ScaleDownRemainingEngineState.COMPLETE + return True + + else: + assert self.state == ScaleDownRemainingEngineState.COMPLETE + return True + + def _progress_removing_engine(self) -> bool: + state = self.state + + if state == ScaleDownRemovingEngineState.PREPARE: + self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE + self.old_dp_store.add("eep_barrier_engine_count", 1) + return True + + if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier(use_new_group=False): + return False + assert self.old_dp_group.rank() > 0 + self._eplb_reshuffle_before_scale_down() + self.state = ScaleDownRemovingEngineState.COMPLETE + self.engine_core._eep_send_worker_notification("SHUTDOWN_COMPLETE") + self.engine_core.shutdown() + return True + + else: + assert self.state == ScaleDownRemovingEngineState.COMPLETE + return True + + def handle_notification(self, notification_type: str): + assert self.worker_type != "new" + if ( + notification_type == "NEW_WORKERS_INIT_READY" + and self.state == ScaleUpExistingEningeState.WAIT_NEW_WORKERS_INIT + ): + self.old_dp_store.add("eep_barrier_engine_count", 1) + self.state = ScaleUpExistingEningeState.CREATE_STANDBY_GROUPS + elif ( + notification_type == "NEW_WORKERS_WEIGHTS_INIT_READY" + and self.state == ScaleUpExistingEningeState.WAIT_NEW_WORKERS_WEIGHTS_INIT + ): + self.old_dp_store.add("eep_barrier_engine_count", 1) + self.state = ScaleUpExistingEningeState.TRANSFER_WEIGHTS + + def is_complete(self) -> bool: + if self.scale_type == "scale_up": + return ( + self.state == ScaleUpNewEngineState.COMPLETE + if self.worker_type == "new" + else self.state == ScaleUpExistingEningeState.COMPLETE + ) + return ( + self.state == ScaleDownRemovingEngineState.COMPLETE + if self.worker_type == "shutdown" + else self.state == ScaleDownRemainingEngineState.COMPLETE + ) + + def _create_standby_groups(self): + assert isinstance(self.new_dp_group_or_config, ParallelConfig) + self.new_dp_group_or_config, self.new_dp_store = ( + 
self.new_dp_group_or_config.stateless_init_dp_group(return_store=True) + ) + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("create_standby_groups", self.reconfig_request) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Created standby communication groups") + + def _transfer_weights(self): + assert self.reconfig_request is not None + old_dp_size = self.old_dp_group.size() + new_dp_size = self.reconfig_request.new_data_parallel_size + + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Transferred weights to new workers") + + def _transfer_expert_mapping(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("broadcast_expert_mapping",) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Broadcasted expert mapping to new workers") + + def _sync_kv_cache_memory_size(self): + assert self.engine_core.available_gpu_memory_for_kv_cache > 0 + ParallelConfig.sync_kv_cache_memory_size( + self.new_dp_group_or_config, + self.engine_core.available_gpu_memory_for_kv_cache, + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Synced KV cache memory size to new workers") + + def _switch_and_prepare(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("switch_and_prepare",) + ) + old_dp_group = self.old_dp_group + stateless_destroy_torch_distributed_process_group(old_dp_group) + assert isinstance(self.new_dp_group_or_config, torch.distributed.ProcessGroup) + new_dp_group = self.new_dp_group_or_config + self.engine_core.dp_group = new_dp_group + self.engine_core.dp_rank = new_dp_group.rank() + self.engine_core.dp_store = self.new_dp_store + engines_running = int(self.engine_core.engines_running) + current_wave = self.engine_core.current_wave + step_counter = self.engine_core.step_counter + tensor = torch.tensor( + [engines_running, current_wave, step_counter], + dtype=torch.int32, + device="cpu", + ) + torch.distributed.all_reduce( + tensor, op=torch.distributed.ReduceOp.MAX, group=self.new_dp_group_or_config + ) + data = tensor.tolist() + self.engine_core.engines_running = bool(data[0]) + self.engine_core.current_wave = int(data[1]) + self.engine_core.step_counter = int(data[2]) + if self.new_dp_group_or_config.rank() == 0: + self.engine_core._eep_send_worker_notification("RECONFIGURE_FINISHED") + logger.info("[Elastic EP] Switched to new setup") + + def _eplb_reshuffle(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("perform_eplb_reshuffle",) + ) + assert isinstance(self.new_dp_group_or_config, torch.distributed.ProcessGroup) + if self.new_dp_group_or_config.rank() == 0: + logger.info("[Elastic EP] EPLB reshuffle completed") + + def _eplb_reshuffle_before_scale_down(self): + assert self.reconfig_request is not None + self.model_executor.collective_rpc( + "elastic_ep_execute", + args=( + "perform_eplb_reshuffle", + self.reconfig_request.new_data_parallel_size, + ), + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] EPLB reshuffle completed") + + def _update_parallel_config(self): + assert self.reconfig_request is not None + reconfig_request = self.reconfig_request + parallel_config = self.vllm_config.parallel_config + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + 
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( + reconfig_request.new_data_parallel_rank_local + ) + parallel_config.data_parallel_master_ip = ( + reconfig_request.new_data_parallel_master_ip + ) + parallel_config.data_parallel_master_port = ( + reconfig_request.new_data_parallel_master_port + ) + parallel_config._data_parallel_master_port_list = ( + reconfig_request.new_data_parallel_master_port_list + ) + parallel_config._stateless_world_group_port_list = ( + reconfig_request.new_stateless_world_group_port_list + ) + parallel_config._stateless_dp_group_port_list = ( + reconfig_request.new_stateless_dp_group_port_list + ) + parallel_config._stateless_ep_group_port_list = ( + reconfig_request.new_stateless_ep_group_port_list + ) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index e4b4fc92eeaa..7f71725d2716 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -24,7 +24,6 @@ def start_async_worker( state: "EplbState", - rank_mapping: dict[int, int] | None = None, is_profile: bool = False, ) -> threading.Thread: ep_group = get_ep_group().device_group @@ -43,7 +42,6 @@ def thread_target() -> None: state=state, ep_group=ep_group, is_profile=is_profile, - rank_mapping=rank_mapping, cuda_stream=cuda_stream, ) ) @@ -61,7 +59,6 @@ async def transfer_run_periodically( state: "EplbState", ep_group: ProcessGroup, is_profile: bool = False, - rank_mapping: dict[int, int] | None = None, cuda_stream: torch.cuda.Stream = None, ) -> None: while True: @@ -99,7 +96,6 @@ async def transfer_run_periodically( is_profile=is_profile, layer=model_state.layer_to_transfer, cuda_stream=cuda_stream, - rank_mapping=rank_mapping, ) event = torch.cuda.Event(blocking=False) cuda_stream.record_event(event) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 9f8798a96a2f..e9928f57d05a 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -40,6 +40,7 @@ get_node_count, in_the_same_node_as, ) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -255,6 +256,14 @@ def __init__(self, parallel_config: ParallelConfig, device: torch.device): """ CUDA device index for the async EPLB worker thread. """ + self.num_valid_physical_experts: int = 0 + """ + Number of valid physical experts. + This is the number of physical experts that are + actually mapped to logical experts. In elastic EP, + newly started EP ranks may not have physical experts + mapped yet. + """ if self.device.type == "cuda": self.cuda_device_index = self.device.index if self.cuda_device_index is None and torch.cuda.is_available(): @@ -320,9 +329,6 @@ def add_model( self, model: MixtureOfExperts, model_config: ModelConfig, - global_expert_load: torch.Tensor | None = None, - old_global_expert_indices: torch.Tensor | None = None, - rank_mapping: dict[int, int] | None = None, ): """ Build the initial EPLB state. 
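The scale-up/scale-down state machines earlier in this patch coordinate engines through a counter key in the shared DP TCPStore (`eep_barrier_engine_count`) rather than a blocking collective barrier, so each engine keeps stepping while it waits for the others. A minimal sketch of that pattern, assuming a plain `torch.distributed.TCPStore`; the helper names below are illustrative, not the actual vLLM API:

```python
import torch.distributed as dist


def arrive(store: dist.TCPStore, key: str = "eep_barrier_engine_count") -> None:
    # Each engine checks in exactly once when it reaches the barrier state.
    store.add(key, 1)


def all_arrived(store: dist.TCPStore, group_size: int,
                key: str = "eep_barrier_engine_count") -> bool:
    # Non-blocking poll: True once every engine in the old DP group has
    # checked in; callers simply retry on the next step otherwise.
    return int(store.get(key)) >= group_size  # TCPStore.get returns bytes
```

Rank 0 later removes the key with `delete_key` once the barrier has been passed, mirroring `_progress_remaining_engine` above.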
@@ -415,71 +421,11 @@ def add_model( ) self.expert_rearrangement_step_interval = eplb_step_interval - if global_expert_load is not None: - ep_group = get_ep_group().device_group - assert global_expert_load.shape == ( - model.num_moe_layers, - model.num_logical_experts, - ) - assert global_expert_load.dtype == torch.int64 - - num_replicas = model.num_physical_experts - num_groups = model.num_expert_groups - num_nodes = get_node_count() - num_gpus = ep_group.size() - - if num_gpus % num_nodes != 0: - num_nodes = 1 - logger.warning_once( - f"num_gpus % num_nodes != 0, " - "not using hierarchical rearrangement algorithm.\n" - f"{num_gpus=}, {num_nodes=}" - ) - - # Get new expert mappings - ( - new_physical_to_logical_map, - new_logical_to_physical_map, - new_logical_replica_count, - ) = rebalance_experts( - global_expert_load, - num_replicas, - num_groups, - num_nodes, - num_gpus, - ) - - max_physical_slots = new_logical_to_physical_map.shape[-1] - assert max_physical_slots <= logical_to_physical_map.shape[-1] - new_logical_to_physical_map = torch.nn.functional.pad( - new_logical_to_physical_map, - (0, logical_to_physical_map.shape[-1] - max_physical_slots), - value=-1, - ) - physical_to_logical_map = new_physical_to_logical_map.to(self.device) - logical_to_physical_map.copy_(new_logical_to_physical_map) - logical_replica_count.copy_(new_logical_replica_count) - else: - new_physical_to_logical_map = None - - new_logical_to_physical_map = None - - new_logical_replica_count = None model.set_eplb_state( expert_load_pass, logical_to_physical_map, logical_replica_count, ) - if global_expert_load is not None: - rearrange_expert_weights_inplace( - old_global_expert_indices, - new_physical_to_logical_map, - model.expert_weights, - ep_group, - False, - rank_mapping, - ) - self.expert_rearrangement_step = 0 expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]] @@ -503,11 +449,12 @@ def add_model( experts_recv_loc={}, is_async_enabled=self.is_async, cuda_device_index=self.cuda_device_index, - new_physical_to_logical_map=new_physical_to_logical_map, - new_logical_to_physical_map=new_logical_to_physical_map, - new_logical_replica_count=new_logical_replica_count, + new_physical_to_logical_map=None, + new_logical_to_physical_map=None, + new_logical_replica_count=None, ) self.model_states[model_config.compute_hash()] = model_state + self.num_valid_physical_experts = model.num_physical_experts def step( self, @@ -651,8 +598,6 @@ def step( def rearrange( self, is_profile: bool = False, - execute_shuffle: bool = True, - global_expert_loads: list[torch.Tensor] | None = None, rank_mapping: dict[int, int] | None = None, ) -> torch.Tensor | None: """ @@ -662,12 +607,6 @@ def rearrange( is_profile (bool): If `True`, perform a dummy rearrangement. This is used in `profile_run` to reserve enough memory, no memory movement will be performed. Default is False. - execute_shuffle (bool): If `True`, execute the shuffle - in elastic expert parallel (EEP). Default is True. - global_expert_loads (list[torch.Tensor] | None): The global expert - loads when scaling is done in EEP. - List of expert loads for the main and drafter - (when spec decode is used) models. rank_mapping (dict[int, int] | None): The rank mapping when scaling is done in EEP. 
""" @@ -686,67 +625,34 @@ def rearrange( "(profile)" if is_profile else "", ) - if global_expert_loads is None: - # Map the physical expert load to global logical experts - global_expert_load_windows = [] - if not execute_shuffle: - num_models = torch.tensor( - [len(self.model_states)], dtype=torch.int32, device="cpu" - ) - torch.distributed.broadcast( - num_models, group=get_ep_group().cpu_group, group_src=0 - ) - - for eplb_model_state in self.model_states.values(): - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - eplb_model_state.model.num_moe_layers, - eplb_model_state.model.num_logical_experts, - dtype=eplb_model_state.expert_load_window.dtype, - device=eplb_model_state.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=eplb_model_state.physical_to_logical_map.unsqueeze(0) - .expand_as(eplb_model_state.expert_load_window) - .long(), - src=eplb_model_state.expert_load_window, - ) - - if not execute_shuffle: - metadata = torch.tensor( - [ - eplb_model_state.model.num_moe_layers, - eplb_model_state.model.num_logical_experts, - eplb_model_state.physical_to_logical_map.shape[1], - ], - dtype=torch.int32, - device="cpu", - ) - torch.distributed.broadcast( - metadata, group=get_ep_group().cpu_group, group_src=0 - ) - - global_expert_load_window = logical_expert_load_window.sum(dim=0) - global_expert_load_windows.append(global_expert_load_window) - # Perform all-reduce to get the expert load across all ranks for each model - global_expert_load_windows = self._allreduce_list( - global_expert_load_windows + # Map the physical expert load to global logical experts + global_expert_load_windows = [] + for eplb_model_state in self.model_states.values(): + expert_load_window = eplb_model_state.expert_load_window[ + :, :, : self.num_valid_physical_experts + ] + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + eplb_model_state.model.num_moe_layers, + eplb_model_state.model.num_logical_experts, + dtype=eplb_model_state.expert_load_window.dtype, + device=eplb_model_state.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=eplb_model_state.physical_to_logical_map[ + :, : self.num_valid_physical_experts + ] + .unsqueeze(0) + .expand_as(expert_load_window) + .long(), + src=expert_load_window, ) - if not execute_shuffle: - for eplb_model_state, global_expert_load_window in zip( - self.model_states.values(), global_expert_load_windows - ): - # (num_moe_layers, old_num_physical_experts) - old_global_expert_indices = eplb_model_state.physical_to_logical_map - torch.distributed.broadcast( - old_global_expert_indices, group=ep_group, group_src=0 - ) - if not execute_shuffle: - return global_expert_load_windows - else: - assert execute_shuffle - global_expert_load_windows = global_expert_loads + + global_expert_load_window = logical_expert_load_window.sum(dim=0) + global_expert_load_windows.append(global_expert_load_window) + # Perform all-reduce to get the expert load across all ranks for each model + global_expert_load_windows = self._allreduce_list(global_expert_load_windows) # TODO(bowen): Treat differently for prefill and decode nodes eplb_model_state = next(iter(self.model_states.values())) @@ -758,8 +664,10 @@ def rearrange( # NOTE(yongji): scale down, we need to rebalance the experts on # remaining GPUs, transfer the experts while we haven't shutdown # the GPUs to be released. 
- cpu_group = get_ep_group().cpu_group - num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) + coordinator = get_ep_group() + assert isinstance(coordinator, StatelessGroupCoordinator) + tcp_store_group = coordinator.tcp_store_group + num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping) num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values()) num_replicas = ( num_replicas // ep_group.size() * num_gpus @@ -882,7 +790,6 @@ def start_async_loop( if self.async_worker is None: self.async_worker = start_async_worker( self, - rank_mapping=rank_mapping, is_profile=is_profile, ) @@ -1002,83 +909,6 @@ def post_eplb(self, model_state: EplbModelState, is_profile: bool = False) -> No model_state.new_logical_to_physical_map = None model_state.new_logical_replica_count = None - @staticmethod - def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """ - Receive the expert load and old placement from the master rank. - """ - ep_group = get_ep_group() - num_models = torch.empty(1, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0) - num_models = num_models.item() - global_expert_loads = [] - old_global_expert_indices_per_model = [] - for _ in range(num_models): - metadata = torch.empty(3, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) - num_moe_layers, num_logical_experts, num_old_physical_experts = ( - metadata.tolist() - ) - global_expert_load = torch.zeros( - (num_moe_layers, num_logical_experts), - dtype=torch.int64, - device=ep_group.device, - ) - all_reduce(global_expert_load, group=ep_group.device_group) - old_global_expert_indices = torch.empty( - (num_moe_layers, num_old_physical_experts), - dtype=torch.int64, - device=ep_group.device, - ) - torch.distributed.broadcast( - old_global_expert_indices, - group=ep_group.device_group, - group_src=0, - ) - global_expert_loads.append(global_expert_load) - old_global_expert_indices_per_model.append(old_global_expert_indices) - return global_expert_loads, old_global_expert_indices_per_model - - @classmethod - def get_eep_state( - cls, parallel_config: ParallelConfig - ) -> tuple[ - list[torch.Tensor] | None, - list[torch.Tensor] | None, - dict[int, int] | None, - ]: - num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") - torch.distributed.broadcast( - num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0, - ) - num_local_physical_experts = int(num_local_physical_experts.item()) - new_ep_size = get_ep_group().world_size - global_expert_loads, old_global_expert_indices_per_model = ( - EplbState.recv_state() - ) - - # EP configuration for all models has to be the same so as eplb config - num_logical_experts = global_expert_loads[0].shape[1] - parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts - ) - assert ( - old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts - == 0 - ) - old_ep_size = ( - old_global_expert_indices_per_model[0].shape[1] - // num_local_physical_experts - ) - rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} - return ( - global_expert_loads, - old_global_expert_indices_per_model, - rank_mapping, - ) - def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]: """ All-reduce a list of tensors. 
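The physical-to-logical load aggregation that `rearrange()` now restricts to the valid physical slots boils down to a `scatter_add_` along the expert dimension. A standalone sketch with illustrative shapes (the real code also carries a load-window dimension and loops over per-model states):

```python
import torch

num_layers, num_physical, num_logical = 2, 8, 4
num_valid_physical_experts = 6  # trailing slots on newly added ranks are unmapped

# Per-layer load observed on each physical expert slot.
physical_load = torch.randint(0, 10, (num_layers, num_physical), dtype=torch.int64)
# -1 marks physical slots not yet mapped to any logical expert.
physical_to_logical = torch.tensor([[0, 1, 2, 3, 0, 1, -1, -1],
                                    [3, 2, 1, 0, 2, 3, -1, -1]])

logical_load = torch.zeros(num_layers, num_logical, dtype=torch.int64)
logical_load.scatter_add_(
    dim=-1,
    index=physical_to_logical[:, :num_valid_physical_experts].long(),
    src=physical_load[:, :num_valid_physical_experts],
)
```

Slicing to the valid prefix matters because unmapped slots hold -1, which would be an invalid scatter index.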
@@ -1116,6 +946,60 @@ def _sync_load_pass(self) -> list[torch.Tensor]: load_pass_list.append(eplb_model_state.expert_load_pass.clone()) return self._allreduce_list(load_pass_list) + @classmethod + def from_mapping( + cls, + model: MixtureOfExperts, + model_config: ModelConfig, + device: torch.device, + parallel_config: ParallelConfig, + expanded_physical_to_logical: torch.Tensor, + num_valid_physical_experts: int, + ) -> "EplbState": + eplb_state = cls( + parallel_config=parallel_config, + device=device, + ) + eplb_state.add_model( + model=model, + model_config=model_config, + ) + eplb_state.num_valid_physical_experts = num_valid_physical_experts + num_moe_layers = expanded_physical_to_logical.shape[0] + num_physical_experts = expanded_physical_to_logical.shape[1] + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical) + + logical_to_physical_map = torch.full( + ( + num_moe_layers, + model.num_logical_experts, + eplb_model_state.logical_to_physical_map.shape[2], + ), + -1, + dtype=torch.int64, + ) + logical_replica_count = torch.zeros( + (num_moe_layers, model.num_logical_experts), + dtype=torch.int64, + ) + expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy() + for layer_idx in range(num_moe_layers): + for phys_idx in range(num_physical_experts): + logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx] + if logical_idx >= 0: + replica_idx = logical_replica_count[layer_idx, logical_idx] + logical_to_physical_map[layer_idx, logical_idx, replica_idx] = ( + phys_idx + ) + logical_replica_count[layer_idx, logical_idx] += 1 + + logical_to_physical_map = logical_to_physical_map.to(device) + logical_replica_count = logical_replica_count.to(device) + eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map) + eplb_model_state.logical_replica_count.copy_(logical_replica_count) + return eplb_state + def _node_count_with_rank_mapping( pg: ProcessGroup | StatelessProcessGroup, diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 376dad8a72ef..8002c8343683 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -17,6 +17,9 @@ batch_isend_irecv, get_global_rank, ) +from torch.distributed.distributed_c10d import _world + +from vllm.distributed.parallel_state import get_ep_group def idx_local_to_global( @@ -142,6 +145,11 @@ def move_to_buffer( buffer[dst].copy_(weight[src], non_blocking=True) p2p_ops: list[P2POp] = [] + if ep_group not in _world.pg_map: + ep_group = get_ep_group() + is_stateless = True + else: + is_stateless = False # 2. Initiate sending of weights. experts_send_loc: dict[int, int] = {} @@ -176,15 +184,23 @@ def move_to_buffer( recv_ranks.append(ranks_to_recv[recver_pos]) for dst in recv_ranks: - dst_global = get_global_rank(ep_group, dst) - p2p_ops += [ - P2POp( - torch.distributed.isend, - weight[src], - dst_global, - ) - for weight in expert_weights - ] + if is_stateless: + for weight in expert_weights: + op = object.__new__(P2POp) + op.op = torch.distributed.isend + op.tensor = weight[src] + op.group_peer = dst + p2p_ops.append(op) + else: + dst_global = get_global_rank(ep_group, dst) + p2p_ops += [ + P2POp( + torch.distributed.isend, + weight[src], + dst_global, + ) + for weight in expert_weights + ] # 3. Initiate receiving of weights. 
experts_recv_loc: dict[int, int] = {} @@ -216,26 +232,40 @@ def move_to_buffer( else: src = ranks_to_send[recver_pos - remainder_start] - src_global = get_global_rank(ep_group, src) - p2p_ops += [ - P2POp( - torch.distributed.irecv, - weight[dst], - src_global, - ) - for weight in expert_weights_buffer - ] + if is_stateless: + for weight in expert_weights_buffer: + op = object.__new__(P2POp) + op.op = torch.distributed.irecv + op.tensor = weight[dst] + op.group_peer = src + p2p_ops.append(op) + else: + src_global = get_global_rank(ep_group, src) + p2p_ops += [ + P2POp( + torch.distributed.irecv, + weight[dst], + src_global, + ) + for weight in expert_weights_buffer + ] # 4. Execute the P2P operations. The real communication happens here. if p2p_ops and cuda_stream is not None: with torch.cuda.stream(cuda_stream): + if is_stateless: + ep_group.device_communicator.batch_isend_irecv(p2p_ops) + else: + reqs = batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + elif p2p_ops: + if is_stateless: + ep_group.device_communicator.batch_isend_irecv(p2p_ops) + else: reqs = batch_isend_irecv(p2p_ops) for req in reqs: req.wait() - elif p2p_ops: - reqs = batch_isend_irecv(p2p_ops) - for req in reqs: - req.wait() # wait for the communication to finish return is_unchanged, is_received_locally, experts_recv_loc diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 52b433cfaf1b..9fb29b184f62 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -33,7 +33,7 @@ from dataclasses import dataclass from datetime import timedelta from multiprocessing import shared_memory -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional from unittest.mock import patch import torch @@ -56,6 +56,9 @@ supports_custom_op, ) +if TYPE_CHECKING: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + @dataclass class GraphCaptureContext: @@ -1024,7 +1027,7 @@ def combine( return hidden_states -_WORLD: GroupCoordinator | None = None +_WORLD: "GroupCoordinator | StatelessGroupCoordinator | None" = None _INNER_DP_WORLD: GroupCoordinator | None = None _NODE_COUNT: int | None = None @@ -1091,12 +1094,12 @@ def get_dcp_group() -> GroupCoordinator: _PP: GroupCoordinator | None = None -def get_pp_group() -> GroupCoordinator: +def get_pp_group() -> "GroupCoordinator | StatelessGroupCoordinator": assert _PP is not None, "pipeline model parallel group is not initialized" return _PP -_DP: GroupCoordinator | None = None +_DP: "GroupCoordinator | StatelessGroupCoordinator | None" = None def get_dp_group() -> GroupCoordinator: @@ -1104,10 +1107,10 @@ def get_dp_group() -> GroupCoordinator: return _DP -_EP: GroupCoordinator | None = None +_EP: "GroupCoordinator | StatelessGroupCoordinator | None" = None -def get_ep_group() -> GroupCoordinator: +def get_ep_group() -> "GroupCoordinator | StatelessGroupCoordinator": assert _EP is not None, "expert parallel group is not initialized" return _EP @@ -1120,6 +1123,24 @@ def get_pcp_group() -> GroupCoordinator: return _PCP +_STANDBY_DP: "StatelessGroupCoordinator | None" = None +_STANDBY_EP: "StatelessGroupCoordinator | None" = None +_STANDBY_WORLD: "StatelessGroupCoordinator | None" = None +_STANDBY_WORLD_NODE_COUNT: int | None = None + + +def get_standby_dp_group() -> "StatelessGroupCoordinator | None": + return _STANDBY_DP + + +def get_standby_ep_group() -> "StatelessGroupCoordinator | None": + return _STANDBY_EP + + +def get_standby_world_group() -> 
"StatelessGroupCoordinator | None": + return _STANDBY_WORLD + + @contextmanager def graph_capture(device: torch.device): """ @@ -1169,6 +1190,7 @@ def init_distributed_environment( from vllm.config import get_current_vllm_config config = get_current_vllm_config() + enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep if config is not None and config.parallel_config.nnodes > 1: parallel_config = config.parallel_config ip = parallel_config.master_addr @@ -1180,6 +1202,7 @@ def init_distributed_environment( config is not None and config.parallel_config.data_parallel_size > 1 and config.parallel_config.distributed_executor_backend != "external_launcher" + and not enable_elastic_ep ): parallel_config = config.parallel_config # adjust to take into account data parallelism @@ -1226,6 +1249,18 @@ def init_distributed_environment( rank=rank, timeout=timeout, ) + if enable_elastic_ep: + tp_pp_cpu_group = torch.distributed.new_group( + backend="gloo", timeout=timeout + ) + if _node_count(tp_pp_cpu_group) > 1: + # NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip + # to initialize all DP/EP groups, hence all ranks within TP/PP group + # must reside on the same node + raise RuntimeError( + "Elastic EP is not yet supported with multi-node TP/PP" + ) + # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1234,6 +1269,35 @@ def init_distributed_environment( # setting, where we can use rank as local rank local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank global _WORLD, _NODE_COUNT, _INNER_DP_WORLD + if enable_elastic_ep: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + # Create stateless world group with all ranks + assert _WORLD is None, "world group already initialized" + parallel_config = config.parallel_config + global_rank = parallel_config.data_parallel_rank * world_size + rank + global_world_size = parallel_config.world_size_across_dp + all_ranks = list(range(global_world_size)) + group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)] + if global_rank in all_ranks: + group_ranks = [all_ranks] + group_ports = [parallel_config.get_next_stateless_world_group_port()] + _WORLD = StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + host=parallel_config.data_parallel_master_ip, + group_ports=group_ports, + global_rank=global_rank, + global_world_size=global_world_size, + ) + assert config.parallel_config.nnodes_within_dp == 1, ( + "Elastic EP is not supported with multi-node TP/PP" + ) + _NODE_COUNT = _node_count(_WORLD.tcp_store_group) # type: ignore[union-attr] + return if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) @@ -1297,9 +1361,6 @@ def initialize_model_parallel( """ # Get world size and rank. Ensure some consistencies. 
assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - backend = backend or torch.distributed.get_backend(get_world_group().device_group) data_parallel_size = 1 from vllm.config import get_current_vllm_config @@ -1308,6 +1369,19 @@ def initialize_model_parallel( if config is not None: data_parallel_size = config.parallel_config.data_parallel_size + enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep + if enable_elastic_ep: + # Use stateless world group for global information + world_size = get_world_group().world_size + rank = get_world_group().rank + backend = backend or "nccl" + else: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group + ) + # the layout order is: ExternalDP x DP x PP x TP # ExternalDP is the data parallel group that is not part of the model, # every dp rank can generate independently (in verl integration). @@ -1330,7 +1404,19 @@ def initialize_model_parallel( assert _TP is None, "tensor model parallel group is already initialized" group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - + if enable_elastic_ep: + tp_pp_pcp_size = ( + tensor_model_parallel_size + * pipeline_model_parallel_size + * prefill_context_model_parallel_size + ) + local_all_ranks = torch.arange(tp_pp_pcp_size).reshape( + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + tensor_model_parallel_size, + ) + group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group( group_ranks, @@ -1349,6 +1435,21 @@ def initialize_model_parallel( # TP group into tp_size//dcp_size DCP groups. 
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + tp_pp_pcp_size = ( + tensor_model_parallel_size + * pipeline_model_parallel_size + * prefill_context_model_parallel_size + ) + local_all_ranks = torch.arange(tp_pp_pcp_size).reshape( + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + tensor_model_parallel_size, + ) + group_ranks = local_all_ranks.reshape( + -1, decode_context_model_parallel_size + ).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] _DCP = init_model_parallel_group( group_ranks, get_world_group().local_rank, @@ -1365,6 +1466,23 @@ def initialize_model_parallel( .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + tp_pp_pcp_size = ( + tensor_model_parallel_size + * pipeline_model_parallel_size + * prefill_context_model_parallel_size + ) + local_all_ranks = torch.arange(tp_pp_pcp_size).reshape( + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + tensor_model_parallel_size, + ) + group_ranks = ( + local_all_ranks.transpose(1, 2) + .reshape(-1, prefill_context_model_parallel_size) + .unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] _PCP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="pcp" ) @@ -1376,6 +1494,23 @@ def initialize_model_parallel( all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + tp_pp_pcp_size = ( + tensor_model_parallel_size + * pipeline_model_parallel_size + * prefill_context_model_parallel_size + ) + local_all_ranks = torch.arange(tp_pp_pcp_size).reshape( + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + tensor_model_parallel_size, + ) + group_ranks = ( + local_all_ranks.transpose(0, 2) + .reshape(-1, pipeline_model_parallel_size) + .unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="pp" ) @@ -1384,9 +1519,28 @@ def initialize_model_parallel( assert _DP is None, "data parallel group is already initialized" group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - _DP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="dp" - ) + if enable_elastic_ep: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + parallel_config = config.parallel_config + group_ports = [ + parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks + ] + _DP = StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + group_name="dp", + host=parallel_config.data_parallel_master_ip, + group_ports=group_ports, + global_rank=get_world_group().rank, + global_world_size=get_world_group().world_size, + ) + else: + _DP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="dp" + ) global _EP assert _EP is None, "expert parallel group is already initialized" @@ -1401,9 +1555,28 @@ def initialize_model_parallel( .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] - _EP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="ep" - ) + if enable_elastic_ep: + from 
vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + parallel_config = config.parallel_config + group_ports = [ + parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks + ] + _EP = StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + group_name="ep", + host=parallel_config.data_parallel_master_ip, + group_ports=group_ports, + global_rank=get_world_group().rank, + global_world_size=get_world_group().world_size, + ) + else: + _EP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="ep" + ) logger.info_once( "rank %s in world size %s is assigned as " @@ -1430,7 +1603,13 @@ def ensure_model_parallel_initialized( or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. """ - backend = backend or torch.distributed.get_backend(get_world_group().device_group) + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + world_group = get_world_group() + if isinstance(world_group, StatelessGroupCoordinator): + backend = backend or world_group.backend + else: + backend = backend or torch.distributed.get_backend(world_group.device_group) if not model_parallel_is_initialized(): initialize_model_parallel( tensor_model_parallel_size, @@ -1460,6 +1639,96 @@ def ensure_model_parallel_initialized( ) +def create_standby_groups( + new_dp_size: int, + new_world_size_across_dp: int, + master_ip: str, + world_group_ports: list[list[int]], + dp_group_ports: list[list[int]], + ep_group_ports: list[list[int]], + backend: str | None = None, +) -> None: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + global _STANDBY_WORLD, _STANDBY_WORLD_NODE_COUNT, _STANDBY_DP, _STANDBY_EP + + assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size + world_group = get_world_group() + assert isinstance(world_group, StatelessGroupCoordinator) + backend = backend or world_group.backend + local_rank = world_group.local_rank + global_rank = world_group.rank + + standby_world_ranks = [list(range(new_world_size_across_dp))] + _STANDBY_WORLD = StatelessGroupCoordinator( + group_ranks=standby_world_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + host=master_ip, + group_ports=world_group_ports, + global_rank=global_rank, + global_world_size=new_world_size_across_dp, + ) + _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group) + + tp_size = get_tp_group().world_size + pp_size = get_pp_group().world_size + + all_ranks = torch.arange(new_world_size_across_dp).reshape( + -1, new_dp_size, pp_size, tp_size + ) + standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0) + standby_dp_ranks = [x.tolist() for x in standby_dp_ranks] + _STANDBY_DP = StatelessGroupCoordinator( + group_ranks=standby_dp_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + group_name="dp", + host=master_ip, + group_ports=dp_group_ports, + global_rank=global_rank, + global_world_size=new_world_size_across_dp, + ) + + standby_ep_ranks = ( + all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0) + ) + standby_ep_ranks = [x.tolist() for x in standby_ep_ranks] + _STANDBY_EP = StatelessGroupCoordinator( + group_ranks=standby_ep_ranks, + 
local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_device_communicator=True,
+        group_name="ep",
+        host=master_ip,
+        group_ports=ep_group_ports,
+        global_rank=global_rank,
+        global_world_size=new_world_size_across_dp,
+    )
+
+
+def switch_to_standby_groups() -> None:
+    global _WORLD, _STANDBY_WORLD, _NODE_COUNT, _STANDBY_WORLD_NODE_COUNT
+    global _DP, _EP, _STANDBY_DP, _STANDBY_EP
+    assert _DP is not None
+    assert _EP is not None
+    assert _WORLD is not None
+    _DP.destroy()
+    _EP.destroy()
+    _WORLD.destroy()
+    _DP = _STANDBY_DP
+    _EP = _STANDBY_EP
+    _WORLD = _STANDBY_WORLD
+    _NODE_COUNT = _STANDBY_WORLD_NODE_COUNT
+    _STANDBY_DP = None
+    _STANDBY_EP = None
+    _STANDBY_WORLD = None
+    _STANDBY_WORLD_NODE_COUNT = None
+
+
 def prepare_communication_buffer_for_model(model: torch.nn.Module):
     """Prepare the communication buffer for the model.
     Traditional communication libraries like NCCL are almost
diff --git a/vllm/distributed/stateless_coordinator.py b/vllm/distributed/stateless_coordinator.py
new file mode 100644
index 000000000000..b8173d7afa07
--- /dev/null
+++ b/vllm/distributed/stateless_coordinator.py
@@ -0,0 +1,318 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
+import torch
+from torch.distributed import Backend, ProcessGroup
+
+from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
+from vllm.distributed.parallel_state import (
+    GroupCoordinator,
+    TensorMetadata,
+    _get_unique_name,
+    _register_group,
+    _split_tensor_dict,
+)
+from vllm.distributed.utils import (
+    StatelessProcessGroup,
+    stateless_destroy_torch_distributed_process_group,
+    stateless_init_torch_distributed_process_group,
+)
+from vllm.logger import init_logger
+from vllm.utils.import_utils import resolve_obj_by_qualname
+
+logger = init_logger(__name__)
+
+
+class StatelessGroupCoordinator(GroupCoordinator):
+    """
+    A stateless version of the GroupCoordinator class in parallel_state.
+    It creates CPU, device, and TCPStore-based communication groups
+    that are independent of PyTorch's WORLD group. Hence,
+    communication groups with a different set of participant GPUs
+    can be created without destroying the existing ones.
+ """ + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: str | Backend, + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: str | None = None, + host: str = "127.0.0.1", + group_ports: list[list[int]] | None = None, + global_rank: int = 0, + global_world_size: int = 1, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = global_rank + self.local_rank = local_rank + + self_device_group = None + self_cpu_group = None + self_tcp_store_group = None + + from vllm.platforms import current_platform + + backend = str(torch_distributed_backend) + self.backend = backend + assert group_ports is not None, "group_ports is not provided" + for idx, ranks in enumerate(group_ranks): + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + + ports = group_ports[idx] + device_port = ports[0] + cpu_port = ports[1] + tcp_store_port = ports[2] + + device_group = stateless_init_torch_distributed_process_group( + host=host, + port=device_port, + rank=self.rank_in_group, + world_size=self.world_size, + backend=backend, + group_name=f"{self.unique_name}_device", + ) + cpu_group = stateless_init_torch_distributed_process_group( + host=host, + port=cpu_port, + rank=self.rank_in_group, + world_size=self.world_size, + backend="gloo", + group_name=f"{self.unique_name}_cpu", + ) + tcp_store_group = StatelessProcessGroup.create( + host=host, + port=tcp_store_port, + rank=self.rank_in_group, + world_size=self.world_size, + ) + + self_device_group = device_group + self_cpu_group = cpu_group + self_tcp_store_group = tcp_store_group + + assert self_cpu_group is not None + assert self_device_group is not None + assert self_tcp_store_group is not None + + self.cpu_group = self_cpu_group + self.device_group = self_device_group + self.tcp_store_group = self_tcp_store_group + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + elif current_platform.is_xpu(): + self.device = torch.device(f"xpu:{local_rank}") + elif current_platform.is_out_of_tree(): + self.device = torch.device(f"{current_platform.device_name}:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_device_communicator = use_device_communicator + self.device_communicator = None + if use_device_communicator and self.world_size > 1: + device_comm_cls = resolve_obj_by_qualname( + current_platform.get_device_communicator_cls() + ) + assert device_comm_cls == CudaCommunicator + self.device_communicator = CudaCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + global_ranks=self.ranks, + global_world_size=global_world_size, + tcp_store_group=self.tcp_store_group, + ) + + self.mq_broadcaster = None + + self.use_custom_op_call = ( + current_platform.is_cuda_alike() or current_platform.is_tpu() + ) + self.use_cpu_custom_send_recv = False + + def destroy(self): + if self.device_communicator: + self.device_communicator.destroy() + if self.device_group: + stateless_destroy_torch_distributed_process_group(self.device_group) + if self.cpu_group: + stateless_destroy_torch_distributed_process_group(self.cpu_group) + + def broadcast(self, input_: torch.Tensor, src: int = 0): + if self.world_size == 1: + return input_ + + if self.device_communicator and input_.is_cuda: + return 
self.device_communicator.broadcast(input_, src) + else: + return self.tcp_store_group.broadcast(input_, src) + + def broadcast_object(self, obj=None, src: int = 0): + if self.world_size == 1: + return obj + return self.tcp_store_group.broadcast_obj(obj, src) + + def broadcast_object_list( + self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None + ): + assert src < self.world_size + + if self.world_size == 1: + return obj_list + + if self.rank_in_group == src: + for obj in obj_list: + self.tcp_store_group.broadcast_obj(obj, src) + else: + for i in range(len(obj_list)): + obj_list[i] = self.tcp_store_group.broadcast_obj(None, src) + + return obj_list + + def broadcast_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any] | None = None, + src: int = 0, + group: ProcessGroup | None = None, + metadata_group: ProcessGroup | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return tensor_dict + + if self.rank_in_group == src: + assert isinstance(tensor_dict, dict), ( + f"Expecting a dictionary, got {type(tensor_dict)}" + ) + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + else: + metadata_list = None + tensor_list = [] + + recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj( + metadata_list, src + ) + + if self.rank_in_group != src: + tensor_dict = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + tensor_list.append(tensor) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + + for tensor in tensor_list: + if tensor.numel() == 0: + continue + if self.device_communicator and tensor.is_cuda: + tensor.copy_(self.device_communicator.broadcast(tensor, src)) + else: + tensor.copy_(self.tcp_store_group.broadcast(tensor, src)) + + return tensor_dict + + def send_object(self, obj, dst: int) -> None: + assert dst < self.world_size + assert dst != self.rank_in_group + self.tcp_store_group.send_obj(obj, dst) + + def recv_object(self, src: int): + assert src < self.world_size + assert src != self.rank_in_group + return self.tcp_store_group.recv_obj(src) + + def send_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any], + dst: int | None = None, + all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: dict[str, bool] | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return tensor_dict + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size + + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + self.tcp_store_group.send_obj(metadata_list, dst) + + for tensor in tensor_list: + if tensor.numel() == 0: + continue + if self.device_communicator and tensor.is_cuda: + self.device_communicator.send(tensor, dst) + else: + self.tcp_store_group.send(tensor, dst) + + return None + + def recv_tensor_dict( + self, + src: int | None = None, + all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: dict[str, bool] | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return None + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size + + recv_metadata_list = self.tcp_store_group.recv_obj(src) + tensor_dict = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) 
+ if tensor.numel() > 0: + if self.device_communicator and tensor.is_cuda: + tensor = self.device_communicator.recv( + tensor.size(), tensor.dtype, src + ) + else: + tensor = self.tcp_store_group.recv(tensor, src) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + self.tcp_store_group.barrier() + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> torch.Tensor | None: + if self.world_size == 1: + return input_ + + if self.device_communicator is None: + raise ValueError("No device communicator found") + + if self.rank_in_group == dst: + gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)] + gathered_list[self.rank_in_group] = input_ + for src_rank in range(self.world_size): + if src_rank != self.rank_in_group: + gathered_list[src_rank] = self.device_communicator.recv( + input_.size(), input_.dtype, src_rank + ) + return torch.cat(gathered_list, dim=dim) + else: + self.device_communicator.send(input_, dst) + return None diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 242ce393e4dc..6b66b11e6b28 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -18,7 +18,7 @@ from typing import Any import torch -from torch.distributed import ProcessGroup, TCPStore +from torch.distributed import ProcessGroup, Store, TCPStore from torch.distributed.distributed_c10d import ( Backend, PrefixStore, @@ -229,6 +229,54 @@ def all_gather_obj(self, obj: Any) -> list[Any]: gathered_objs.append(recv_obj) return gathered_objs + def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + """Broadcast a tensor from source rank to all other ranks.""" + if self.rank == src: + tensor_bytes = pickle.dumps(tensor) + self.expire_data() + key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}" + self.store.set(key, tensor_bytes) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return tensor + else: + key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}" + tensor = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return tensor + + def send(self, tensor: torch.Tensor, dst: int): + """Send a tensor to a destination rank.""" + self.expire_data() + key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(tensor)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def recv(self, tensor: torch.Tensor, src: int): + """Receive a tensor from a source rank.""" + key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}" + received = pickle.loads(self.store.get(key)) + self.recv_src_counter[src] += 1 + tensor.copy_(received) + + def all_reduce( + self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM + ) -> torch.Tensor: + """All-reduce a tensor across all ranks.""" + tensors = self.all_gather_obj(tensor) + result = tensors[0].clone() + for t in tensors[1:]: + if op == torch.distributed.ReduceOp.SUM: + result.add_(t) + elif op == torch.distributed.ReduceOp.PRODUCT: + result.mul_(t) + elif op == torch.distributed.ReduceOp.MAX: + result = torch.maximum(result, t) + elif op == torch.distributed.ReduceOp.MIN: + result = torch.minimum(result, t) + return result + def barrier(self, timeout: float = 30.0): """A robust barrier to synchronize all ranks. 
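A usage sketch for the store-backed helpers added to `StatelessProcessGroup` above: each rank's tensor is gathered through the TCPStore and reduced locally, so no global torch.distributed world group is needed. The host and port here are placeholders; in the elastic-EP path these groups are created by `StatelessGroupCoordinator`.

```python
import torch
import torch.distributed  # for ReduceOp

from vllm.distributed.utils import StatelessProcessGroup


def sum_expert_load(rank: int, world_size: int,
                    local_load: torch.Tensor) -> torch.Tensor:
    # Placeholder host/port; every participating rank must use the same pair.
    pg = StatelessProcessGroup.create(
        host="127.0.0.1", port=29555, rank=rank, world_size=world_size
    )
    # SUM is the default; MAX/MIN/PRODUCT take the same gather-then-reduce path.
    return pg.all_reduce(local_load, op=torch.distributed.ReduceOp.SUM)
```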
@@ -460,8 +508,14 @@ def init_gloo_process_group( def stateless_init_torch_distributed_process_group( - host: str, port: int, rank: int, world_size: int, backend: str -) -> ProcessGroup: + host: str, + port: int, + rank: int, + world_size: int, + backend: str, + group_name: str | None = None, + return_store: bool = False, +) -> ProcessGroup | tuple[ProcessGroup, Store]: """ A replacement for `torch.distributed.init_process_group` that does not pollute the global state. The created ProcessGroup object can be used for @@ -508,26 +562,36 @@ def stateless_init_torch_distributed_process_group( # Use a PrefixStore to avoid accidental overrides of keys used by # different systems (e.g. RPC) in case the store is multi-tenant. prefix_store = PrefixStore(init_method, store) - try: - from vllm.platforms import current_platform - return current_platform.stateless_init_device_torch_dist_pg( - backend=backend, + if backend == "gloo": + pg = init_gloo_process_group( prefix_store=prefix_store, group_rank=group_rank, group_size=group_size, timeout=timeout, ) - except NotImplementedError: - # If platform doesn't implement stateless_init_device_torch_dist_pg, it - # will raise a NotImplementedError. In this case, we fall back to gloo. - return init_gloo_process_group( + else: + from vllm.platforms import current_platform + + pg = current_platform.stateless_init_device_torch_dist_pg( + backend=backend, prefix_store=prefix_store, group_rank=group_rank, group_size=group_size, timeout=timeout, ) + if group_name is not None: + from torch._C._distributed_c10d import _register_process_group + + pg._set_group_name(group_name) + _register_process_group(group_name, pg) + + if return_store: + return pg, store + else: + return pg + def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: """ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e4c9a82d2522..e6fd0879613e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -407,6 +407,7 @@ class EngineArgs: data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel all2all_backend: str | None = ParallelConfig.all2all_backend + enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep enable_dbo: bool = ParallelConfig.enable_dbo dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold @@ -815,6 +816,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--all2all-backend", **parallel_kwargs["all2all_backend"] ) parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) + parallel_group.add_argument( + "--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"] + ) parallel_group.add_argument( "--dbo-decode-token-threshold", **parallel_kwargs["dbo_decode_token_threshold"], @@ -1607,6 +1611,7 @@ def create_engine_config( data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, all2all_backend=self.all2all_backend, + enable_elastic_ep=self.enable_elastic_ep, enable_dbo=self.enable_dbo, dbo_decode_token_threshold=self.dbo_decode_token_threshold, dbo_prefill_token_threshold=self.dbo_prefill_token_threshold, diff --git a/vllm/envs.py b/vllm/envs.py index 56558548d398..da444670f9be 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -176,6 +176,7 @@ "pplx", "deepep_high_throughput", "deepep_low_latency", + "nixl_ep", "allgather_reducescatter", 
"flashinfer_all2allv", ] = "allgather_reducescatter" @@ -232,6 +233,12 @@ VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False + VLLM_NIXL_EP_MAX_NUM_RANKS: int = 128 + VLLM_NIXL_EP_ETCD_ENDPOINTS: str | None = None + VLLM_NIXL_EP_UCX_IB_DEVICES: str | None = None + VLLM_NIXL_EP_UCX_TCP_DEVICES: str | None = None + VLLM_NIXL_EP_PLUGIN_DIR: str | None = None def get_default_cache_root(): @@ -1251,6 +1258,7 @@ def get_vllm_port() -> int | None: "pplx", "deepep_high_throughput", "deepep_low_latency", + "nixl_ep", "allgather_reducescatter", "flashinfer_all2allv", ], @@ -1526,6 +1534,27 @@ def get_vllm_port() -> int | None: "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), + # Whether it is a scale up launch engine for elastic EP, + # Should only be set by EngineCoreClient. + "VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": lambda: bool( + int(os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH", "0")) + ), + # NIXL EP environment variables + # These are temporarily registered here for Ray to pass to downstream + # EngineCore actors + "VLLM_NIXL_EP_MAX_NUM_RANKS": lambda: int( + os.getenv("VLLM_NIXL_EP_MAX_NUM_RANKS", "128") + ), + "VLLM_NIXL_EP_ETCD_ENDPOINTS": lambda: os.getenv( + "VLLM_NIXL_EP_ETCD_ENDPOINTS", None + ), + "VLLM_NIXL_EP_UCX_IB_DEVICES": lambda: os.getenv( + "VLLM_NIXL_EP_UCX_IB_DEVICES", None + ), + "VLLM_NIXL_EP_UCX_TCP_DEVICES": lambda: os.getenv( + "VLLM_NIXL_EP_UCX_TCP_DEVICES", None + ), + "VLLM_NIXL_EP_PLUGIN_DIR": lambda: os.getenv("VLLM_NIXL_EP_PLUGIN_DIR", None), } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 86c50f39f007..3040af721e76 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -16,7 +16,7 @@ FusedMoEPrepareAndFinalize, ) from vllm.platforms import current_platform -from vllm.utils.import_utils import has_deep_ep, has_pplx +from vllm.utils.import_utils import has_deep_ep, has_nixl_ep, has_pplx if current_platform.is_cuda_alike(): if has_pplx(): @@ -30,6 +30,8 @@ DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize, ) + if has_nixl_ep(): + from .nixl_ep_prepare_finalize import NixlEPPrepareAndFinalize def maybe_roundup_layer_hidden_size( @@ -61,6 +63,11 @@ def maybe_roundup_layer_hidden_size( hidden_size ) + if moe_parallel_config.use_nixl_ep_kernels: + hidden_size = NixlEPPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size + ) + return hidden_size @@ -168,4 +175,39 @@ def maybe_make_prepare_finalize( local_expert_global_ids=local_expert_global_ids, ) + elif moe.use_nixl_ep_kernels: + assert quant_config is not None + global_to_physical = physical_to_global = local_expert_global_ids = None + if routing_tables is not None: + ( + global_to_physical, + physical_to_global, + local_expert_global_ids, + ) = routing_tables + all_to_all_args = dict( + max_num_tokens_per_dp_rank=moe.max_num_tokens, + token_hidden_size=moe.hidden_dim, + num_ep_ranks=all2all_manager.world_size, + num_global_experts=moe.num_experts, + num_local_experts=moe.num_experts // all2all_manager.world_size, + ) + handle = all2all_manager.get_handle(all_to_all_args) + + # Note: We may want to use FP8 dispatch just to reduce + # data movement. 
+ use_fp8_dispatch = ( + quant_config.quant_dtype == current_platform.fp8_dtype() + and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE + ) + + prepare_finalize = NixlEPPrepareAndFinalize( + handle, + max_tokens_per_rank=moe.max_num_tokens, + num_dispatchers=all2all_manager.world_size, + use_fp8_dispatch=use_fp8_dispatch, + global_to_physical=global_to_physical, + physical_to_global=physical_to_global, + local_expert_global_ids=local_expert_global_ids, + ) + return prepare_finalize diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 1826fafa8c4f..ffc684f58540 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -719,6 +719,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" + @property + def use_nixl_ep_kernels(self): + return self.use_all2all_kernels and self.all2all_backend == "nixl_ep" + @staticmethod def flatten_tp_across_dp_and_pcp( tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int @@ -926,6 +930,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_nixl_ep_kernels(self): + return self.moe_parallel_config.use_nixl_ep_kernels + @property def use_flashinfer_cutlass_kernels(self): """ diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0ef3130b2633..8eba77ad893e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -216,10 +216,11 @@ def determine_expert_placement_strategy( if ( moe_parallel_config.use_all2all_kernels and not moe_parallel_config.use_deepep_ll_kernels + and not moe_parallel_config.use_nixl_ep_kernels ): logger.warning( "Round-robin expert placement currently only supports " - "the DeepEP low-latency backend, but '%s' was configured. " + "the DeepEP low-latency or NIXL EP backend, but '%s' was configured. " "Falling back to linear expert placement.", moe_parallel_config.all2all_backend, ) @@ -644,6 +645,7 @@ def _get_quant_method() -> FusedMoEMethodBase: moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) + self.base_quant_method = self.quant_method # Chunked all2all staging tensor self.batched_hidden_states: torch.Tensor | None = None @@ -658,7 +660,7 @@ def maybe_init_modular_kernel(self) -> None: # routing_tables only needed for round-robin expert placement with # DeepEP all2all backend. 
routing_tables = self._maybe_init_expert_routing_tables() - prepare_finalize = self.quant_method.maybe_make_prepare_finalize( + prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize( routing_tables=routing_tables ) if prepare_finalize is not None: @@ -666,7 +668,7 @@ def maybe_init_modular_kernel(self) -> None: "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) ) self.quant_method = FusedMoEModularMethod.make( - self, self.quant_method, prepare_finalize, self.shared_experts + self, self.base_quant_method, prepare_finalize, self.shared_experts ) @property @@ -742,6 +744,7 @@ def use_dp_chunking(self) -> bool: return ( self.moe_parallel_config.use_pplx_kernels or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_parallel_config.use_nixl_ep_kernels or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) ) @@ -758,6 +761,7 @@ def _maybe_init_expert_routing_tables( if ( self.expert_placement_strategy != "round_robin" or not self.use_deepep_ll_kernels + or not self.use_nixl_ep_kernels ): return None diff --git a/vllm/model_executor/layers/fused_moe/nixl_ep_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/nixl_ep_prepare_finalize.py new file mode 100644 index 000000000000..6898b6a7401c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/nixl_ep_prepare_finalize.py @@ -0,0 +1,400 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import nixl_ep +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input, + normalize_batched_scales_shape, +) +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, + dbo_enabled, + dbo_maybe_run_recv_hook, +) + +logger = init_logger(__name__) + +# NIXL EP kernels quantize dispatch inputs in 128 element chunks. +NIXL_EP_QUANT_BLOCK_SIZE = 128 +NIXL_EP_QUANT_BLOCK_SHAPE = [NIXL_EP_QUANT_BLOCK_SIZE, NIXL_EP_QUANT_BLOCK_SIZE] + + +def dequant_fp8( + expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor +) -> torch.Tensor: + """ + Return dequantized tensor in fp32 + """ + assert expert_x_fp8.is_contiguous() + expert_x_scales = expert_x_scales.contiguous() + num_experts = expert_x_fp8.size(0) + + expert_x_fp32 = expert_x_fp8.to(torch.float32).view( + num_experts, -1, NIXL_EP_QUANT_BLOCK_SIZE + ) + expert_x_scales = expert_x_scales.view(num_experts, -1, 1) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) + + +class NixlEPPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + Prepare/Finalize using NIXL EP kernels. + """ + + # NIXL EP kernels are compiled only for certain specific hidden sizes. + # NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends + # on it. + SUPPORTED_HIDDEN_SIZES = [2048, 2560, 3072, 4096, 5120, 6144, 7168, 8192] + + @staticmethod + def maybe_roundup_layer_hidden_size(hidden_size: int) -> int: + # Round up hidden size to the closest supported hidden size. 
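+        # e.g. a hidden size of 3584 is padded up to 4096; sizes above the
+        # largest supported value raise a ValueError below.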
+ _supported_hs = NixlEPPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES + # Check sorted + num_supported_hs = len(_supported_hs) + assert all( + [ + _supported_hs[i] < _supported_hs[i + 1] + for i in range(num_supported_hs - 1) + ] + ) + + for x in _supported_hs: + if x >= hidden_size: + return x + + raise ValueError( + f"Hidden Size {hidden_size} is greater than the " + f"maximum supported hidden size {_supported_hs[-1]}" + ) + + def __init__( + self, + buffer: nixl_ep.Buffer, + max_tokens_per_rank: int, + num_dispatchers: int, + use_fp8_dispatch: bool = False, + global_to_physical: torch.Tensor | None = None, + physical_to_global: torch.Tensor | None = None, + local_expert_global_ids: torch.Tensor | None = None, + ): + super().__init__() + + self.buffer = buffer + self.max_tokens_per_rank = max_tokens_per_rank + self.use_fp8_dispatch = use_fp8_dispatch + # The dispatch function returns a handle that the combine function + # requires. We store the handle here so it is available to the + # combine function. + self.handles: list[tuple | None] = [None, None] + self.num_dispatchers_ = num_dispatchers + + topk_indices_dtype = self.topk_indices_dtype() + + def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None: + if tensor is None or topk_indices_dtype is None: + return tensor + return tensor.to(dtype=topk_indices_dtype) + + self.global_to_physical = _maybe_cast(global_to_physical) + self.physical_to_global = _maybe_cast(physical_to_global) + self.local_expert_global_ids = _maybe_cast(local_expert_global_ids) + + # We don't have enough information to determine if we should dispatch + # activation scales in a packed ue8m0 format during object construction + # time. This setting is handled by post_init_setup. + self.use_ue8m0_dispatch = False + + def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute): + if not fused_experts.supports_packed_ue8m0_act_scales(): + # Early exit. + return + + if self.use_fp8_dispatch: + logger.debug_once( + "Update NixlEPPrepareAndFinalize to do packed ue8m0 scales dispatch." + ) + self.use_ue8m0_dispatch = True + else: + logger.warning_once( + "NixlEPPrepareAndFinalize is setup to dispatch raw/unquantized " + f"activations despite ({fused_experts.__class__.__name__}) being able " + "to support quantized activations.", + scope="local", + ) + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def output_is_reduced(self) -> bool: + return True + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.BatchedExperts + + def max_num_tokens_per_rank(self) -> int | None: + return self.max_tokens_per_rank + + def topk_indices_dtype(self) -> torch.dtype | None: + return torch.int64 + + def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor: + if self.global_to_physical is None: + return topk_ids + return self.global_to_physical[topk_ids] + + def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor: + if self.local_expert_global_ids is None: + return expert_topk_ids + return self.local_expert_global_ids[expert_topk_ids] + + def _do_quant( + self, + x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + a1_dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if self.use_fp8_dispatch: + block_k = ( + quant_config.block_shape[1] + if quant_config.block_shape is not None + else None + ) + if block_k == NIXL_EP_QUANT_BLOCK_SIZE: + # NIXL EP kernels did the quantization for us. 
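+                # x is already an (fp8 tensor, block scales) tuple here, so it
+                # can be returned without re-quantizing.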
+ x, x_scales = x + return x, x_scales + + # Dequant to get back the tokens in the datatype we dispatched in. + x_fp8, x_scales = x + x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype) + + assert isinstance(x, torch.Tensor) + + num_experts, max_tokens, hidden_dim = x.size() + + x = x.view((-1, hidden_dim)) + q_dtype = quant_config.quant_dtype + + if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm": + logger.info_once( + "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) " + "for ModelOptNvFp4FusedMoE." + ) + q_dtype = None + + x, x_scales = moe_kernel_quantize_input( + x, + quant_config.a1_scale, + q_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) + x = x.view((num_experts, -1, hidden_dim)) + + if q_dtype is not None: + assert x_scales is not None + x_scales = normalize_batched_scales_shape(x_scales, num_experts) + + return x, x_scales + + def supports_async(self) -> bool: + return True + + def prepare_async( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> tuple[Callable, mk.ReceiverType]: + hidden_size = a1.size(1) + assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, ( + f"Hidden Size {hidden_size} not in supported list of hidden sizes" + f"{self.SUPPORTED_HIDDEN_SIZES}" + ) + + a2a_idx = dbo_current_ubatch_id() + + if self.use_fp8_dispatch: + assert hidden_size % 128 == 0, ( + "NIXL EP kernels quantize the inputs in blocks of shape 128" + ) + + has_per_token_scales = ( + quant_config.a1_scale.numel() != 1 + if quant_config.a1_scale is not None + else ( + quant_config.a2_scale.numel() != 1 + if quant_config.a2_scale is not None + else False + ) + ) + assert not has_per_token_scales, ( + "NIXL EP kernels don't support dispatching per-token scales" + ) + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + a1 = a1 * topk_weights.to(a1.dtype) + + # Dispatch - use buffer.dispatch instead of buffer.low_latency_dispatch + dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids) + expert_x, expert_num_tokens, handle, _, hook = self.buffer.dispatch( + a1, + dispatch_topk_ids, + self.max_tokens_per_rank, + num_experts, + use_fp8=self.use_fp8_dispatch, + # round_scale needs to be set to dispatch in ue8m0 + round_scale=self.use_ue8m0_dispatch, + use_ue8m0=self.use_ue8m0_dispatch, + async_finish=False, + return_recv_hook=True, + ) + self.handles[a2a_idx] = handle + + return ( + hook, + lambda: self._receiver( + expert_x, + expert_num_tokens, + quant_config.a1_scale, + a1.dtype, + quant_config, + ), + ) + + def _receiver( + self, + expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + expert_num_tokens: torch.Tensor, + a1_scale: torch.Tensor | None, + a1_dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config) + + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None + ) + + return expert_x, expert_x_scale, expert_tokens_meta, None, None + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: 
FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + hook, receiver = self.prepare_async( + a1, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) + hook() + return receiver() + + def _finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + do_async: bool, + ) -> tuple[Callable, Callable]: + assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), ( + "Weight application and reduction happens in the combine kernel." + ) + + a2a_idx = dbo_current_ubatch_id() + do_recv_hook = dbo_enabled() or do_async + handle = self.handles[a2a_idx] + assert handle is not None + + combine_topk_weights = topk_weights + if apply_router_weight_on_input: + # weights have already been applied. + combine_topk_weights = torch.ones_like(topk_weights) + + combine_topk_ids = self._map_global_to_physical_ids(topk_ids) + # Use buffer.combine instead of buffer.low_latency_combine + dbo_maybe_run_recv_hook() + _, _, recv_hook = self.buffer.combine( + fused_expert_output, + combine_topk_ids, + combine_topk_weights, + handle, + async_finish=False, + zero_copy=False, + return_recv_hook=do_recv_hook, + out=output, + ) + + return recv_hook, lambda: None + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> tuple[Callable, Callable]: + return self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + do_async=True, + ) + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + do_async=False, + ) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index bc241ac692e2..f5eeaa3616a5 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -737,7 +737,11 @@ def _interleave_mxfp4_cutlass_sm90(w): # batched activation format. As self.fused_experts is not # initialized at this point, we resort to checking the MoE config # directly. 
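+        # NIXL EP also uses the batched activation format, so it shares the
+        # same num_warps heuristic as pplx and DeepEP low-latency.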
- is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels + is_batched_moe = ( + self.moe.use_pplx_kernels + or self.moe.use_deepep_ll_kernels + or self.moe.use_nixl_ep_kernels + ) if is_batched_moe: num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 else: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 4bf9401b6b05..9585f1b6c2d4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -6,10 +6,13 @@ import os from collections.abc import Callable +from datetime import timedelta from functools import cache, wraps from typing import TYPE_CHECKING, TypeVar import torch +from torch.distributed import PrefixStore, ProcessGroup +from torch.distributed.distributed_c10d import is_nccl_available from typing_extensions import ParamSpec # import custom ops, trigger op registration @@ -442,6 +445,37 @@ def opaque_attention_op(cls) -> bool: def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" + @classmethod + def stateless_init_device_torch_dist_pg( + cls, + backend: str, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, + ) -> ProcessGroup: + assert is_nccl_available() + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, group_size, backend_options + ) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + return pg + @classmethod def device_count(cls) -> int: return cuda_device_count_stateless() diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ccf3446a3a6e..51e1d3d1f447 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -2,10 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from datetime import timedelta from functools import cache, lru_cache, wraps from typing import TYPE_CHECKING import torch +from torch.distributed import PrefixStore, ProcessGroup +from torch.distributed.distributed_c10d import is_nccl_available import vllm.envs as envs from vllm.attention.backends.registry import AttentionBackendEnum @@ -479,6 +482,37 @@ def is_navi(cls) -> bool: def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" + @classmethod + def stateless_init_device_torch_dist_pg( + cls, + backend: str, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, + ) -> ProcessGroup: + assert is_nccl_available() + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, group_size, backend_options + ) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + return pg + @classmethod def device_count(cls) -> int: return cuda_device_count_stateless() diff --git a/vllm/utils/import_utils.py 
b/vllm/utils/import_utils.py index ff0f0350fd94..21106641fdff 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -428,6 +428,11 @@ def has_deep_gemm() -> bool: return _has_module("deep_gemm") +def has_nixl_ep() -> bool: + """Whether the optional `nixl_ep` package is available.""" + return _has_module("nixl_ep") + + def has_triton_kernels() -> bool: """Whether the optional `triton_kernels` package is available.""" is_available = _has_module("triton_kernels") or _has_module( diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index ce2aae77108d..9d9d357375f7 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -203,6 +203,10 @@ class ReconfigureDistributedRequest(msgspec.Struct): new_data_parallel_rank_local: int new_data_parallel_master_ip: str new_data_parallel_master_port: int + new_data_parallel_master_port_list: list[int] + new_stateless_world_group_port_list: list[list[int]] + new_stateless_dp_group_port_list: list[list[int]] + new_stateless_ep_group_port_list: list[list[int]] class ReconfigureRankType(enum.IntEnum): diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 827a2736af28..5ab2dd573eb3 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -480,7 +480,6 @@ def _run_output_handler(self): engine_core = self.engine_core output_processor = self.output_processor log_stats = self.log_stats - logger_manager = self.logger_manager processor = self.processor async def output_handler(): @@ -527,8 +526,10 @@ async def output_handler(): # 4) Logging. # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. - if logger_manager: - logger_manager.record( + if self.logger_manager: + # NOTE(yongji): we need to use self.logger_manager here + # since it can be reinstantiated during scaling up + self.logger_manager.record( engine_idx=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, @@ -824,17 +825,6 @@ async def scale_elastic_ep( new_data_parallel_size, ) return - logger.info( - "Waiting for requests to drain before scaling up to %s engines...", - new_data_parallel_size, - ) - await self.wait_for_requests_to_drain(drain_timeout) - logger.info( - "Requests have been drained, proceeding with scale to %s engines", - new_data_parallel_size, - ) - await self.engine_core.scale_elastic_ep(new_data_parallel_size) - self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # recreate stat loggers if new_data_parallel_size > old_data_parallel_size and self.log_stats: @@ -847,6 +837,10 @@ async def scale_elastic_ep( engine_idxs=list(range(new_data_parallel_size)), custom_stat_loggers=None, ) + self.logger_manager.log_engine_initialized() + + await self.engine_core.scale_elastic_ep(new_data_parallel_size) + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size @property def is_running(self) -> bool: diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 953342cdd5d0..c176fd14e3e0 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -71,6 +71,9 @@ def __init__(self, parallel_config: ParallelConfig): ) local_only_eng = dp_size == parallel_config.data_parallel_size_local + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + local_only_eng = False back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) back_output_address = 
get_engine_client_zmq_addr(local_only_eng, host) @@ -192,6 +195,7 @@ def process_input_socket( poller = zmq.Poller() poller.register(publish_front, zmq.POLLIN) + poller.register(publish_back, zmq.POLLIN) poller.register(output_back, zmq.POLLIN) last_publish_time = 0 while True: @@ -222,6 +226,22 @@ def process_input_socket( events = dict(events) wave_state_changed = False + if publish_back in events: + buffer = publish_back.recv() + if buffer == b"\x01": + # NOTE(yongji): newly started engine subscribed + # We need to send READY message here instead of receiving + # SCALE_ELASTIC_EP notification from engine core client + # as SCALE_ELASTIC_EP is only sent when + # new engines finished initialization. + # Subscription message, on the other hand, is sent + # by each engine during initialization + publish_back.send(b"READY") + else: + logger.error( + "DP Coordinator receives unexpected message from engines" + ) + if publish_front in events: buffer = publish_front.recv() if buffer in (b"\x01", b"\x00"): @@ -250,7 +270,6 @@ def process_input_socket( # current_wave # we note that 0 is the wave number for the new # engine - engines_running = False logger.info( "DPCoordinator scaled up from %s to %s engines", current_count, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 8657a95b5e6e..17e424394d29 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -16,6 +16,7 @@ import msgspec import zmq +import vllm.envs as envs from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.envs import enable_envs_cache @@ -105,6 +106,9 @@ def __init__( self.available_gpu_memory_for_kv_cache = -1 + if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: + self._eep_scale_up_before_kv_init() + # Setup KV Caches and update CacheConfig after profiling. 
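+        # On an elastic EP scale-up launch the available KV cache memory was
+        # already synced from the existing engines in the call above.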
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( vllm_config @@ -219,12 +223,10 @@ def _initialize_kv_caches( has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs) if has_kv_cache: - if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": - dp_group = getattr(self, "dp_group", None) - assert dp_group is not None - self.available_gpu_memory_for_kv_cache = ( - ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) - ) + if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: + # NOTE(yongji): should already be set + # during _eep_scale_up_before_kv_init + assert self.available_gpu_memory_for_kv_cache > 0 available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len( kv_cache_specs ) @@ -552,6 +554,14 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i self.structured_output_manager.grammar_init(req) return req, request.current_wave + def _eep_scale_up_before_kv_init(self): + raise NotImplementedError + + def _eep_send_worker_notification( + self, notification_type: str, vllm_config: VllmConfig | None = None + ): + raise NotImplementedError + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" @@ -604,6 +614,12 @@ def __init__( and not vllm_config.parallel_config.data_parallel_external_lb ) + self.addresses = addresses + self.process_input_queue_block = True + if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: + self._eep_send_worker_notification( + "NEW_WORKERS_INIT_READY", vllm_config=vllm_config + ) self._init_data_parallel(vllm_config) super().__init__( @@ -873,8 +889,14 @@ def _process_input_queue(self): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True - req = self.input_queue.get() - self._handle_client_request(*req) + block = self.process_input_queue_block + try: + req = self.input_queue.get(block=block) + self._handle_client_request(*req) + except queue.Empty: + break + if not block: + break if waited: logger.debug("EngineCore loop active.") @@ -1017,6 +1039,11 @@ def process_input_sockets( for input_socket, _ in poller.poll(): # (RequestType, RequestData) type_frame, *data_frames = input_socket.recv_multipart(copy=False) + # NOTE(yongji): ignore READY message sent by DP coordinator + # that is used to notify newly started engines + if type_frame.buffer == b"READY": + assert input_socket == coord_socket + continue request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. @@ -1119,6 +1146,10 @@ def __init__( self.current_wave = 0 self.last_counts = (0, 0) + from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState + + self.eep_scaling_state: ElasticEPScalingState | None = None + # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank super().__init__( @@ -1153,7 +1184,9 @@ def _init_data_parallel(self, vllm_config: VllmConfig): ) self.dp_rank = dp_rank - self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.dp_group, self.dp_store = ( + vllm_config.parallel_config.stateless_init_dp_group(return_store=True) + ) def shutdown(self): super().shutdown() @@ -1209,7 +1242,12 @@ def run_busy_loop(self): # 1) Poll the input queue until there is work to do. self._process_input_queue() - # 2) Step the engine core. 
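+            # 2) Drive any in-flight elastic EP scaling, then step the engine
+            #    core.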
+ if self.eep_scaling_state is not None: + _ = self.eep_scaling_state.progress() + if self.eep_scaling_state.is_complete(): + self.process_input_queue_block = True + self.eep_scaling_state = None + executed = self._process_engine_step() self._maybe_publish_request_counts() @@ -1259,54 +1297,124 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest ) -> None: - stateless_destroy_torch_distributed_process_group(self.dp_group) - self.shutdown() - - parallel_config = self.vllm_config.parallel_config - old_dp_size = parallel_config.data_parallel_size - parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size - if reconfig_request.new_data_parallel_rank != -1: - parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank - # local rank specifies device visibility, it should not be changed - assert ( - reconfig_request.new_data_parallel_rank_local - == ReconfigureRankType.KEEP_CURRENT_RANK - ) - parallel_config.data_parallel_master_ip = ( + from copy import deepcopy + + from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState + + new_parallel_config = deepcopy(self.vllm_config.parallel_config) + old_dp_size = new_parallel_config.data_parallel_size + new_parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + new_parallel_config.data_parallel_rank = ( + reconfig_request.new_data_parallel_rank + ) + new_parallel_config.data_parallel_master_ip = ( reconfig_request.new_data_parallel_master_ip ) - parallel_config.data_parallel_master_port = ( + new_parallel_config.data_parallel_master_port = ( reconfig_request.new_data_parallel_master_port ) - if reconfig_request.new_data_parallel_rank != -2: - self.dp_rank = parallel_config.data_parallel_rank - self.dp_group = parallel_config.stateless_init_dp_group() - reconfig_request.new_data_parallel_master_port = ( - parallel_config.data_parallel_master_port + new_parallel_config._data_parallel_master_port_list = ( + reconfig_request.new_data_parallel_master_port_list ) - self.model_executor.reinitialize_distributed(reconfig_request) - if reconfig_request.new_data_parallel_size > old_dp_size: - assert self.available_gpu_memory_for_kv_cache > 0 - # pass available_gpu_memory_for_kv_cache from existing - # engine-cores to new engine-cores so they can directly - # use it in _initialize_kv_caches() rather than profiling. 
- ParallelConfig.sync_kv_cache_memory_size( - self.dp_group, self.available_gpu_memory_for_kv_cache - ) - # NOTE(yongji): newly joined workers require dummy_run even - # CUDA graph is not used - self.model_executor.collective_rpc("compile_or_warm_up_model") - if ( + is_scale_down = reconfig_request.new_data_parallel_size < old_dp_size + is_shutdown = ( reconfig_request.new_data_parallel_rank == ReconfigureRankType.SHUTDOWN_CURRENT_RANK - ): - self.shutdown() - logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) + ) + + self.eep_scaling_state = ElasticEPScalingState( + model_executor=self.model_executor, + engine_core=self, + vllm_config=self.vllm_config, + new_parallel_config=new_parallel_config, + worker_type="removing" if is_shutdown else "existing", + scale_type="scale_down" if is_scale_down else "scale_up", + reconfig_request=reconfig_request, + ) + self.process_input_queue_block = False + logger.info( + "[Elastic EP] Received reconfiguration request and starting scaling up/down" + ) + + def _eep_send_worker_notification( + self, notification_type: str, vllm_config: VllmConfig | None = None + ): + """ + Send notifications to EngineCoreClient, which can then forward + the notifications to other engine processes. It is used for: + 1) In scale up: new workers to notify exisiting workers that they are ready; + 2) In scale down: removing workers to notify EngineCoreClient + so EngineCoreClient can release their ray placement groups; + 3) Both scale up/down: to notify EngineCoreClient that exisiting workers + have already switched to the new parallel setup. + """ + if vllm_config is None: + dp_rank = self.vllm_config.parallel_config.data_parallel_rank else: - logger.info( - "Distributed environment reinitialized for DP rank %s", self.dp_rank + dp_rank = vllm_config.parallel_config.data_parallel_rank + notification_data = (notification_type, dp_rank) + outputs = EngineCoreOutputs( + utility_output=UtilityOutput( + call_id=-1, result=UtilityResult(notification_data) + ) + ) + outputs.engine_index = self.engine_index + + if hasattr(self, "output_thread") and self.output_thread.is_alive(): + self.output_queue.put_nowait((0, outputs)) + else: + encoder = MsgpackEncoder() + with ( + zmq.Context() as ctx, + make_zmq_socket( + ctx, self.addresses.outputs[0], zmq.PUSH, linger=4000 + ) as socket, + ): + socket.send_multipart(encoder.encode(outputs)) + + def eep_handle_worker_notification(self, notification_type: str): + """ + Handle notification received from EngineCoreClient (forwarded from new workers). 
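+        The notification is forwarded to this engine's ElasticEPScalingState.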
+ """ + assert self.eep_scaling_state is not None + self.eep_scaling_state.handle_notification(notification_type) + + def _eep_scale_up_before_kv_init(self): + from vllm.distributed.elastic_ep.elastic_state import ElasticEPScalingState + from vllm.v1.executor.ray_distributed_executor import RayDistributedExecutor + + self.eep_scaling_state = ElasticEPScalingState( + model_executor=self.model_executor, + engine_core=self, + vllm_config=self.vllm_config, + new_parallel_config=self.vllm_config.parallel_config, + worker_type="new", + scale_type="scale_up", + reconfig_request=None, + ) + self.model_executor.collective_rpc("init_device") + kwargs = {} + if isinstance(self.model_executor, RayDistributedExecutor): + kwargs["max_concurrent_workers"] = ( + self.vllm_config.parallel_config.max_parallel_loading_workers ) + self.model_executor.collective_rpc("load_model") + self._eep_send_worker_notification("NEW_WORKERS_WEIGHTS_INIT_READY") + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("receive_weights",) + ) + self.available_gpu_memory_for_kv_cache = ( + ParallelConfig.sync_kv_cache_memory_size(self.dp_group, -1) + ) + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("prepare_new_worker",) + ) + self.process_input_queue_block = False class DPEngineCoreActor(DPEngineCoreProc): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9b440505bd9d..9ec67144fc63 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -26,7 +26,6 @@ from vllm.utils.async_utils import in_loop from vllm.utils.network_utils import ( close_sockets, - get_open_port, get_open_zmq_inproc_path, make_zmq_socket, ) @@ -418,6 +417,54 @@ def validate_alive(self, frames: Sequence[zmq.Frame]): raise EngineDeadError() +@dataclass +class ElasticScalingCache: + existing_workers: list[EngineIdentity] + num_new_workers: int + pending_notifications: dict[str, set[int]] + + +def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int): + """ + Allocate stateless group ports for elastic EP. 
+ """ + from vllm.utils.network_utils import get_open_ports_list + + assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled" + world_size = parallel_config.world_size + new_world_size_across_dp = world_size * new_data_parallel_size + num_world_groups = 1 + num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size) + num_ep_groups = max( + 1, + new_world_size_across_dp + // (new_data_parallel_size * parallel_config.tensor_parallel_size), + ) + total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3 + 5 + all_ports = get_open_ports_list(total_ports_needed) + new_data_parallel_master_port_list = all_ports[-5:] + all_ports = all_ports[:-5] + new_stateless_world_group_port_list = [ + all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) + ] + start_idx = num_world_groups * 3 + new_stateless_dp_group_port_list = [ + all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3) + ] + start_idx += num_dp_groups * 3 + new_stateless_ep_group_port_list = [ + all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3) + ] + + parallel_config._stateless_world_group_port_list = ( + new_stateless_world_group_port_list + ) + parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list + parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list + parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop() + parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list + + class MPClient(EngineCoreClient): """ MPClient: base client for multi-proc EngineCore. @@ -843,6 +890,10 @@ def _ensure_output_queue_task(self): output_socket = resources.output_socket assert output_socket is not None + notification_callback_handler: ( + Callable[[AsyncMPClient, Sequence[Any]], Any] | None + ) = getattr(self.__class__, "eep_process_worker_notification", None) + async def process_outputs_socket(): try: while True: @@ -850,7 +901,28 @@ async def process_outputs_socket(): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, utility_results) + if ( + outputs.utility_output.call_id == -1 + and notification_callback_handler is not None + ): + # NOTE(yongji): call_id -1 in utility_output is + # reserved for elastic EP worker notifications. + assert _self_ref is not None + _self = _self_ref() + if not _self: + return + if outputs.utility_output.result is None: + continue + notification_data = outputs.utility_output.result.result + assert isinstance(notification_data, Sequence) + assert len(notification_data) == 2 + asyncio.create_task( + notification_callback_handler(_self, notification_data) + ) + else: + _process_utility_output( + outputs.utility_output, utility_results + ) continue if output_handler is not None: @@ -1027,6 +1099,8 @@ def __init__( # Used only by DPLBAsyncMPClient subclass. 
self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines] + self.eep_scaling_cache: ElasticScalingCache | None = None + self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = self.resources.first_req_send_socket = ( make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True) @@ -1072,6 +1146,7 @@ async def run_engine_stats_update_task(): poller.register(socket, zmq.POLLIN) poller.register(first_req_rcv_socket, zmq.POLLIN) + nonlocal count_slice while True: events = await poller.poll() if ( @@ -1091,6 +1166,33 @@ async def run_engine_stats_update_task(): ): # Extract new engine count from the decoded message new_engine_count = decoded[1] + # Update engine_ranks_managed and count_slice + parallel_config = self.vllm_config.parallel_config + dp_size = parallel_config.data_parallel_size + dp_rank = parallel_config.data_parallel_rank + assert dp_rank == 0 + assert dp_size == new_engine_count + assert not ( + parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb + ) + num_ranks = dp_size + self.engine_ranks_managed = list( + range(dp_rank, dp_rank + num_ranks) + ) + count_slice = slice( + self.engine_ranks_managed[0], + self.engine_ranks_managed[-1] + 1, + ) + if len(self.lb_engines) < new_engine_count: + self.lb_engines = self.lb_engines + [ + [0, 0] + for _ in range( + new_engine_count - len(self.lb_engines) + ) + ] + else: + self.lb_engines = self.lb_engines[:new_engine_count] # Send scale up notification to coordinator scale_msg = msgspec.msgpack.encode( ("SCALE_ELASTIC_EP", new_engine_count) @@ -1233,6 +1335,55 @@ async def process_engine_outputs( for req_id in outputs.finished_requests: self.reqs_in_flight.pop(req_id, None) + @staticmethod + async def eep_process_worker_notification( + self: "DPLBAsyncMPClient", notification_data: tuple[str, int] + ): + cache = self.eep_scaling_cache + notification_type, dp_rank = notification_data + if notification_type == "RECONFIGURE_FINISHED": + from vllm.v1.engine import UtilityResult + + dummy_output = UtilityOutput(call_id=-1, result=UtilityResult(None)) + _process_utility_output(dummy_output, self.utility_results) + return + assert cache is not None + if notification_type not in cache.pending_notifications: + cache.pending_notifications[notification_type] = set() + if dp_rank in cache.pending_notifications[notification_type]: + raise ValueError( + f"Duplicate notification {notification_type} from dp_rank {dp_rank}" + ) + cache.pending_notifications[notification_type].add(dp_rank) + if len(cache.pending_notifications[notification_type]) >= abs( + cache.num_new_workers + ): + if notification_type == "SHUTDOWN_COMPLETE": + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) + assert cache.num_new_workers < 0 + old_dp_size = len(cache.existing_workers) + new_dp_size = old_dp_size + cache.num_new_workers + self.resources.engine_manager.scale_down_elastic_ep( + old_dp_size, new_dp_size + ) + else: + await asyncio.gather( + *[ + self._call_utility_async( + "eep_handle_worker_notification", + notification_type, + engine=engine, + ) + for engine in cache.existing_workers + ] + ) + cache.pending_notifications[notification_type] = set() + if notification_type in [ + "SHUTDOWN_COMPLETE", + "NEW_WORKERS_WEIGHTS_INIT_READY", + ]: + self.eep_scaling_cache = None + async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids or self.resources.engine_dead: return @@ -1279,6 +1430,20 @@ async def scale_elastic_ep(self, 
new_data_parallel_size: int) -> None: cur_data_parallel_size, new_data_parallel_size ) + async def _eep_wait_for_setup_switch_complete(self) -> None: + """ + Wait for workers to switch to the new setup. + """ + # NOTE(yongji): In eep_process_worker_notification(), + # a dummy UtilityOutput with call_id -1 will be set + # when RECONFIGURE_FINISHED notification is received. + # from engine 0. + call_id = -1 + future = asyncio.get_running_loop().create_future() + self.utility_results[call_id] = future + self._ensure_output_queue_task() + await future + async def _scale_up_elastic_ep( self, cur_data_parallel_size: int, new_data_parallel_size: int ) -> None: @@ -1286,38 +1451,56 @@ async def _scale_up_elastic_ep( and reconfiguring existing ones.""" cur_data_parallel_size = len(self.core_engines) - # Phase 1: Send reconfigure messages to all existing engines and wait - # for them to be sent + self.eep_scaling_cache = ElasticScalingCache( + existing_workers=self.core_engines.copy(), + num_new_workers=new_data_parallel_size - cur_data_parallel_size, + pending_notifications=dict(), + ) + + parallel_config = self.vllm_config.parallel_config + allocate_stateless_group_ports(parallel_config, new_data_parallel_size) + + # Phase 1: Send reconfig messages to existing engines reconfig_futures = [] - self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() for engine in self.core_engines: reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=parallel_config.data_parallel_master_port, + new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, + new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, + new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list, + new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list, ) coro = self._call_utility_async( "reinitialize_distributed", reconfig_request, engine=engine ) reconfig_futures.append(asyncio.create_task(coro)) - logger.info("All reconfigure messages sent, starting engine creation") - - # Phase 2: Create new engines now that reconfig messages have been sent - # self.resources.engine_manager is guaranteed to be - # CoreEngineActorManager for RayDPClient + # Phase 2: Create new engines assert isinstance(self.resources.engine_manager, CoreEngineActorManager) - self.resources.engine_manager.scale_up_elastic_ep( - self.vllm_config, new_data_parallel_size + parallel_config.eplb_config.num_redundant_experts = 0 + start_new_worker_future = asyncio.to_thread( + self.resources.engine_manager.scale_up_elastic_ep, + self.vllm_config, + new_data_parallel_size, ) + wait_future = self._eep_wait_for_setup_switch_complete() + + # Phase 3: Wait for new engines to be created + # and reconfig messages to be received + await asyncio.gather(start_new_worker_future, *reconfig_futures) + logger.info("[Elastic EP] Successfully started new engines") # Create new CoreEngine objects for the new engines new_engine_identities = set() for i in range(cur_data_parallel_size, new_data_parallel_size): new_engine = 
i.to_bytes(2, "little") self.core_engines.append(new_engine) + # NOTE(yongji): we don't update lb_engines here, + # we let run_engine_stats_update_task to update it. new_engine_identities.add(new_engine) # Wait for ready messages from new engines on the input socket @@ -1331,10 +1514,11 @@ async def _scale_up_elastic_ep( identity, _ = sync_input_socket.recv_multipart() new_engine_identities.discard(identity) - # Phase 3: Wait for all existing engines to complete reconfiguration - logger.info("Waiting for existing engines to complete reconfiguration") - await asyncio.gather(*reconfig_futures) - + # NOTE(yongji): Before we schedule any requests on the new workers, + # we should wait for them to switch to the new setup. + await wait_future + # Update the parallel config + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # Notify coordinator about scale up through existing # stats_update_task connection self._ensure_stats_update_task() @@ -1343,8 +1527,6 @@ async def _scale_up_elastic_ep( ) await self.first_req_send_socket.send(scale_up_marker) - # Update the parallel config - self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size logger.info( "[Elastic EP] Scale up completed, new data parallel size: %s", new_data_parallel_size, @@ -1357,7 +1539,14 @@ async def _scale_down_elastic_ep( reconfiguring existing engine cores.""" cur_data_parallel_size = len(self.core_engines) - self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() + self.eep_scaling_cache = ElasticScalingCache( + existing_workers=self.core_engines.copy(), + num_new_workers=new_data_parallel_size - cur_data_parallel_size, + pending_notifications=dict(), + ) + + parallel_config = self.vllm_config.parallel_config + allocate_stateless_group_ports(parallel_config, new_data_parallel_size) reconfig_futures = [] for cur_dp_rank, engine in enumerate(self.core_engines): @@ -1365,8 +1554,12 @@ async def _scale_down_elastic_ep( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + new_data_parallel_master_ip=parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=parallel_config.data_parallel_master_port, + new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list, + new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list, + new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list, + new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list, ) if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = ( @@ -1377,23 +1570,24 @@ async def _scale_down_elastic_ep( ) reconfig_futures.append(asyncio.create_task(coro)) - for _ in range(new_data_parallel_size, cur_data_parallel_size): - self.core_engines.pop() + # NOTE(yongji): Immediately stop sending requests to the removing engines. 
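+        # Truncating core_engines and lb_engines together keeps the load
+        # balancer from scheduling work onto ranks that are about to shut down.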
+ self.core_engines = self.core_engines[:new_data_parallel_size] + self.lb_engines = self.lb_engines[:new_data_parallel_size] + wait_future = self._eep_wait_for_setup_switch_complete() await asyncio.gather(*reconfig_futures) - assert isinstance(self.resources.engine_manager, CoreEngineActorManager) - self.resources.engine_manager.scale_down_elastic_ep( - cur_data_parallel_size, new_data_parallel_size - ) - + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size self._ensure_stats_update_task() scale_down_marker = msgspec.msgpack.encode( ("SCALE_ELASTIC_EP", new_data_parallel_size) ) await self.first_req_send_socket.send(scale_down_marker) - self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size + # NOTE(yongji): Unlike scaling up, + # here we don't actually need to wait for the setup switch to complete. + # We may want to remove it in the future. + await wait_future logger.info( "[Elastic EP] Scale down completed, new data parallel size: %s", new_data_parallel_size, diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index d65cad7af03d..3f54f4bd69bd 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -576,6 +576,8 @@ def add_dp_placement_groups( node_ip = node.node_ip node_id = node.node_id + if device_str not in available_resources[node_id]: + continue available_gpus = int(available_resources[node_id][device_str]) # Get total GPUs on this node from the node's resources @@ -795,6 +797,9 @@ def launch_core_engines( client_local_only = ( offline_mode or local_engines_only or (local_engine_count == dp_size) ) + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + client_local_only = False # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -867,6 +872,10 @@ def launch_core_engines( # will be False. handshake_local_only = offline_mode or local_engine_count == dp_size + # NOTE(yongji): handling scaling from intra-node to inter-node + if parallel_config.enable_elastic_ep: + handshake_local_only = False + handshake_address = get_engine_client_zmq_addr( handshake_local_only, host, parallel_config.data_parallel_rpc_port ) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 7e8ebe25c460..1cb009f2140f 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -549,17 +549,16 @@ def __init__( ) self.async_output_copy_thread.start() - # Initialize device - self.worker.init_device() - - # Set process title and log prefix self.setup_proc_title_and_log_prefix( enable_ep=vllm_config.parallel_config.enable_expert_parallel ) # Load model self._init_message_queues(input_shm_handle, vllm_config) - self.worker.load_model() + is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH + if not is_eep_new_worker: + self.worker.init_device() + self.worker.load_model() # Enable environment variable cache (e.g. 
assume no more # environment variable overrides after this point) diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 406eafcd339b..f8d793374cef 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -367,8 +367,10 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): all_kwargs.append(kwargs) self.collective_rpc("init_worker", args=(all_kwargs,)) - self.collective_rpc("init_device") - self.collective_rpc("load_model") + is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH + if not is_eep_new_worker: + self.collective_rpc("init_device") + self.collective_rpc("load_model") for pp_rank in range(self.parallel_config.pipeline_parallel_size): self.pp_tp_workers.append([]) diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index 095d3d1dac21..13d2f2b54e05 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -14,7 +14,6 @@ from vllm.logger import init_logger from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput -from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.serial_utils import run_method @@ -43,9 +42,11 @@ def _init_executor(self) -> None: max_workers=1, thread_name_prefix="WorkerAsyncOutput" ) + is_eep_new_worker = envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH self.driver_worker.init_worker(all_kwargs=[kwargs]) - self.driver_worker.init_device() - self.driver_worker.load_model() + if not is_eep_new_worker: + self.driver_worker.init_device() + self.driver_worker.load_model() def _distributed_args(self) -> tuple[str, int, int]: """Return (distributed_init_method, rank, local_rank).""" @@ -119,16 +120,6 @@ def check_health(self) -> None: # it's running. return - def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest - ) -> None: - self.driver_worker.reinitialize_distributed(reconfig_request) - if ( - reconfig_request.new_data_parallel_rank - == ReconfigureRankType.SHUTDOWN_CURRENT_RANK - ): - self.shutdown() - def shutdown(self) -> None: if worker := self.driver_worker: worker.shutdown() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0ae4eb48acf2..6f0e7236d2f5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -346,6 +346,8 @@ def __init__( self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.eplb_state: EplbState | None = None + # NOTE(yongji): flag to temporarily disable EPLB during scaling up/down + self.eep_eplb_suppressed = False """ State of the expert parallelism load balancer. @@ -2302,7 +2304,7 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. 
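+
+        The step is skipped while ``eep_eplb_suppressed`` is set, i.e. while an
+        elastic EP scale up/down is in flight.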
""" - if not self.parallel_config.enable_eplb: + if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed: return assert self.eplb_state is not None @@ -2314,6 +2316,23 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: log_stats=self.parallel_config.eplb_config.log_balancedness, ) + def setup_eplb_from_mapping( + self, + expanded_physical_to_logical: torch.Tensor, + old_num_physical_experts: int, + ) -> None: + model = self.get_model() + assert is_mixture_of_experts(model) + + self.eplb_state = EplbState.from_mapping( + model=model, + model_config=self.model_config, + device=self.device, + parallel_config=self.parallel_config, + expanded_physical_to_logical=expanded_physical_to_logical, + num_valid_physical_experts=old_num_physical_experts, + ) + def _pool( self, hidden_states: torch.Tensor, @@ -3404,27 +3423,24 @@ def update_config(self, overrides: dict[str, Any]) -> None: new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) - def load_model(self, eep_scale_up: bool = False) -> None: + def load_model(self, dummy_weights: bool = False) -> None: """ Args: - eep_scale_up: the model loading is for elastic EP scale up. + dummy_weights: load dummy weights instead of real weights. """ logger.info_once( "Starting to load model %s...", self.model_config.model, scope="global", ) - global_expert_loads, old_global_expert_indices_per_model, rank_mapping = ( - EplbState.get_eep_state(self.parallel_config) - if eep_scale_up - else (None, None, None) - ) if self.parallel_config.enable_eplb: self.eplb_state = EplbState(self.parallel_config, self.device) eplb_models = 0 with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() + if dummy_weights: + self.load_config.load_format = "dummy" model_loader = get_model_loader(self.load_config) self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config @@ -3448,25 +3464,15 @@ def load_model(self, eep_scale_up: bool = False) -> None: "EPLB is enabled for drafter model %s.", spec_config.draft_model_config.model, ) - - global_expert_load = ( - global_expert_loads[eplb_models] - if global_expert_loads - else None - ) - old_global_expert_indices = ( - old_global_expert_indices_per_model[eplb_models] - if old_global_expert_indices_per_model - else None + assert not self.parallel_config.enable_elastic_ep, ( + "Elastic EP is not supported with draft model yet." 
                 )
+
                 if self.eplb_state is None:
                     self.eplb_state = EplbState(self.parallel_config, self.device)
                 self.eplb_state.add_model(
                     self.drafter.model,
                     spec_config.draft_model_config,
-                    global_expert_load,
-                    old_global_expert_indices,
-                    rank_mapping,
                 )
                 eplb_models += 1
@@ -3497,11 +3503,12 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             time_after_load - time_before_load,
             scope="local",
         )
-        prepare_communication_buffer_for_model(self.model)
-        if (drafter := getattr(self, "drafter", None)) and (
-            drafter_model := getattr(drafter, "model", None)
-        ):
-            prepare_communication_buffer_for_model(drafter_model)
+        if not dummy_weights:
+            prepare_communication_buffer_for_model(self.model)
+            if (drafter := getattr(self, "drafter", None)) and (
+                drafter_model := getattr(drafter, "model", None)
+            ):
+                prepare_communication_buffer_for_model(drafter_model)
         mm_config = self.model_config.multimodal_config
         self.is_multimodal_pruning_enabled = (
             supports_multimodal_pruning(self.get_model())
@@ -3509,26 +3516,19 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             and mm_config.is_multimodal_pruning_enabled()
         )

-        if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
+        if (
+            is_mixture_of_experts(self.model)
+            and self.parallel_config.enable_eplb
+            and not dummy_weights
+        ):
             logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
-            global_expert_load = (
-                global_expert_loads[eplb_models] if global_expert_loads else None
-            )
-            old_global_expert_indices = (
-                old_global_expert_indices_per_model[eplb_models]
-                if old_global_expert_indices_per_model
-                else None
-            )
             assert self.eplb_state is not None
             self.eplb_state.add_model(
                 self.model,
                 self.model_config,
-                global_expert_load,
-                old_global_expert_indices,
-                rank_mapping,
             )
             if self.eplb_state.is_async:
-                self.eplb_state.start_async_loop(rank_mapping=rank_mapping)
+                self.eplb_state.start_async_loop()

         if (
             self.vllm_config.compilation_config.mode
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index d0c6091ce2a6..f63da870e21e 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -6,11 +6,10 @@
 import os
 from contextlib import AbstractContextManager, nullcontext
 from types import NoneType
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any

 import numpy as np
 import torch
-import torch.distributed
 import torch.nn as nn

 import vllm.envs as envs
@@ -27,14 +26,12 @@
     has_kv_transfer_group,
 )
 from vllm.distributed.parallel_state import (
-    get_pcp_group,
     get_pp_group,
     get_tp_group,
 )
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.model_executor.models.interfaces import is_mixture_of_experts
 from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
 from vllm.platforms import current_platform
 from vllm.profiler.gpu_profiler import CudaProfilerWrapper, TorchProfilerWrapper
@@ -43,7 +40,6 @@
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
-from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (
     AsyncModelRunnerOutput,
@@ -78,6 +74,10 @@ def __init__(
             is_driver_worker=is_driver_worker,
         )

+        from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor
+
+        self.elastic_ep_executor = ElasticEPScalingExecutor(self)
+
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils.import_utils import init_cached_hf_modules
@@ -258,9 +258,26 @@ def init_device(self):
     # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
     # to hijack tensor allocation.
     def load_model(self) -> None:
-        eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
+        dummy_weights = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
+        if dummy_weights:
+            (
+                expanded_physical_to_logical,
+                num_logical_experts,
+                old_num_physical_experts,
+            ) = self.elastic_ep_executor.receive_expert_mapping()
+            num_physical_experts = expanded_physical_to_logical.shape[1]
+            self.parallel_config.eplb_config.num_redundant_experts = (
+                num_physical_experts - num_logical_experts
+            )
+
         with self._maybe_get_memory_pool_context(tag="weights"):
-            self.model_runner.load_model(eep_scale_up=eep_scale_up)
+            self.model_runner.load_model(dummy_weights=dummy_weights)
+
+        if dummy_weights:
+            self.model_runner.setup_eplb_from_mapping(
+                expanded_physical_to_logical, old_num_physical_experts
+            )
+            self.model_runner.eep_eplb_suppressed = True

     def update_config(self, overrides: dict[str, Any]) -> None:
         self.model_runner.update_config(overrides)
@@ -634,223 +651,6 @@ def check_health(self) -> None:
         # worker will always be healthy as long as it's running.
         return

-    def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
-        from vllm.distributed.parallel_state import get_ep_group
-
-        if get_ep_group().rank == 0:
-            logger.info(
-                "[Elastic EP] Starting expert resharding before scaling down..."
-            )
-        rank_mapping = {
-            old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
-            for old_ep_rank in range(old_ep_size)
-        }
-        assert self.model_runner.eplb_state is not None
-        self.model_runner.eplb_state.rearrange(
-            execute_shuffle=True,
-            global_expert_loads=None,
-            rank_mapping=rank_mapping,
-        )
-        torch.cuda.synchronize()
-        if get_ep_group().rank == 0:
-            logger.info("[Elastic EP] Expert resharding completed!")
-
-    def _eplb_after_scale_up(
-        self,
-        old_ep_size: int,
-        new_ep_size: int,
-        global_expert_loads: list[torch.Tensor] | None,
-    ) -> None:
-        from vllm.distributed.parallel_state import get_ep_group
-
-        if get_ep_group().rank == 0:
-            logger.info("[Elastic EP] Starting expert resharding after scaling up...")
-        rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
-        assert self.model_runner.eplb_state is not None
-        self.model_runner.eplb_state.rearrange(
-            execute_shuffle=True,
-            global_expert_loads=global_expert_loads,
-            rank_mapping=rank_mapping,
-        )
-        if get_ep_group().rank == 0:
-            logger.info("[Elastic EP] Expert resharding completed!")
-
-    def _reconfigure_parallel_config(
-        self, reconfig_request: ReconfigureDistributedRequest
-    ) -> None:
-        """
-        Update parallel config with provided reconfig_request
-        """
-        parallel_config = self.vllm_config.parallel_config
-        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
-        if (
-            reconfig_request.new_data_parallel_rank
-            != ReconfigureRankType.KEEP_CURRENT_RANK
-        ):
-            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
-        if (
-            reconfig_request.new_data_parallel_rank_local
-            != ReconfigureRankType.KEEP_CURRENT_RANK
-        ):
-            parallel_config.data_parallel_rank_local = (
-                reconfig_request.new_data_parallel_rank_local
-            )
-        parallel_config.data_parallel_master_ip = (
-            reconfig_request.new_data_parallel_master_ip
-        )
-        parallel_config.data_parallel_master_port = (
-            reconfig_request.new_data_parallel_master_port
-        )
-
-    def _reconfigure_moe(
-        self, old_ep_size: int, new_ep_size: int
-    ) -> list[torch.Tensor] | None:
-        """
-        Reconfigure MoE modules with provided reconfig_request
-
-        Return the global expert load if new_ep_size > old_ep_size,
-        otherwise None
-        """
-        from vllm.distributed.parallel_state import (
-            get_dp_group,
-            get_ep_group,
-            prepare_communication_buffer_for_model,
-        )
-        from vllm.model_executor.layers.fused_moe.layer import (
-            FusedMoE,
-            FusedMoEParallelConfig,
-        )
-
-        parallel_config = self.vllm_config.parallel_config
-
-        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
-            return [
-                module
-                for module in model.modules()
-                if (
-                    module.__class__.__name__ == "FusedMoE"
-                    or module.__class__.__name__ == "SharedFusedMoE"
-                )
-            ]
-
-        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
-            assert all(
-                module.moe_config.num_local_experts == num_local_experts
-                for module in moe_modules
-            ), "All MoE modules must have the same number of experts"
-            for module in moe_modules:
-                module.moe_config.num_experts = num_local_experts * new_ep_size
-                module.global_num_experts = module.moe_config.num_experts
-                module.moe_parallel_config = FusedMoEParallelConfig.make(
-                    tp_size_=get_tp_group().world_size,
-                    pcp_size_=get_pcp_group().world_size,
-                    dp_size_=get_dp_group().world_size,
-                    vllm_parallel_config=parallel_config,
-                )
-                module.moe_config.moe_parallel_config = module.moe_parallel_config
-            return moe_modules
-
-        model_moe_modules = get_moe_modules(self.model_runner.model)
-        num_local_experts = model_moe_modules[0].moe_config.num_local_experts
-
-        update_moe_modules(model_moe_modules, num_local_experts)
-        drafter_model = None
-        if hasattr(self.model_runner, "drafter") and hasattr(
-            self.model_runner.drafter, "model"
-        ):
-            drafter_model = self.model_runner.drafter.model
-        if drafter_model is not None and is_mixture_of_experts(drafter_model):
-            drafter_moe_modules = get_moe_modules(drafter_model)
-            # Check if drafter and model have matching configs
-            assert (
-                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
-            ), "Drafter and model configs should be the same"
-            update_moe_modules(drafter_moe_modules, num_local_experts)
-
-        if new_ep_size < old_ep_size:
-            num_local_physical_experts = num_local_experts
-            assert self.model_runner.eplb_state is not None
-            new_physical_experts = (
-                self.model_runner.eplb_state.physical_to_logical_map.shape[1]  # type: ignore[attr-defined]
-            )
-            parallel_config.eplb_config.num_redundant_experts = (
-                new_physical_experts
-                - self.model_runner.eplb_state.logical_replica_count.shape[1]  # type: ignore[attr-defined]
-            )
-            global_expert_loads = None
-        else:
-            num_local_physical_experts_tensor = torch.tensor(
-                [num_local_experts], dtype=torch.int32, device="cpu"
-            )
-            torch.distributed.broadcast(
-                num_local_physical_experts_tensor,
-                group=get_ep_group().cpu_group,
-                group_src=0,
-            )
-            num_local_physical_experts = int(num_local_physical_experts_tensor.item())
-            new_physical_experts = num_local_physical_experts * new_ep_size
-            assert self.model_runner.eplb_state is not None
-            global_expert_loads_any = self.model_runner.eplb_state.rearrange(
-                execute_shuffle=False
-            )
-            global_expert_loads = cast(list[torch.Tensor], global_expert_loads_any)
-            parallel_config.eplb_config.num_redundant_experts = (
-                new_physical_experts - global_expert_loads[0].shape[1]
-            )
-        prepare_communication_buffer_for_model(self.model_runner.model)
-        if drafter_model is not None:
-            prepare_communication_buffer_for_model(drafter_model)
-        self.model_runner.model.update_physical_experts_metadata(
-            num_physical_experts=new_physical_experts,
-            num_local_physical_experts=num_local_physical_experts,
-        )
-        return global_expert_loads
-
-    def reinitialize_distributed(
-        self, reconfig_request: ReconfigureDistributedRequest
-    ) -> None:
-        from vllm.config import set_current_vllm_config
-        from vllm.distributed.parallel_state import (
-            cleanup_dist_env_and_memory,
-            get_ep_group,
-        )
-
-        old_ep_size = get_ep_group().world_size
-        old_ep_rank = get_ep_group().rank
-        new_ep_size = (
-            reconfig_request.new_data_parallel_size
-            * get_tp_group().world_size
-            * get_pp_group().world_size
-        )
-        if new_ep_size < old_ep_size:
-            self._eplb_before_scale_down(old_ep_size, new_ep_size)
-
-        cleanup_dist_env_and_memory()
-
-        if (
-            reconfig_request.new_data_parallel_rank
-            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
-        ):
-            assert old_ep_rank >= new_ep_size
-            # shutdown
-            return
-
-        self._reconfigure_parallel_config(reconfig_request)
-
-        with set_current_vllm_config(self.vllm_config):
-            init_worker_distributed_environment(
-                self.vllm_config,
-                self.rank,
-                self.distributed_init_method,
-                self.local_rank,
-            )
-
-        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
-
-        if new_ep_size > old_ep_size:
-            assert global_expert_loads is not None
-            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
-
     def save_sharded_state(
         self,
         path: str,
@@ -880,6 +680,9 @@ def shutdown(self) -> None:
         if self.profiler is not None:
             self.profiler.shutdown()

+    def elastic_ep_execute(self, execute_method: str, *args, **kwargs):
+        return self.elastic_ep_executor.execute(execute_method, *args, **kwargs)
+

 def init_worker_distributed_environment(
     vllm_config: VllmConfig,