From 6a50853971dbefb1ef4b22ab5bfb93bb57fc4b1f Mon Sep 17 00:00:00 2001 From: nifeng <1542305589@qq.com> Date: Sat, 28 Mar 2026 16:29:35 +0800 Subject: [PATCH 1/4] [Feat.][EEP] vLLM Ascend adaptation for Elastic EP Milestone2 Signed-off-by: nifeng <1542305589@qq.com> --- vllm_ascend/compilation/acl_graph.py | 23 +- .../device_communicators/npu_communicator.py | 33 +- .../device_communicators/pyhccl.py | 52 +++ .../device_communicators/pyhccl_wrapper.py | 54 +++ .../distributed/elastic_ep/__init__.py | 0 .../distributed/elastic_ep/elastic_execute.py | 342 ++++++++++++++++++ .../distributed/elastic_ep/standby_state.py | 94 +++++ vllm_ascend/distributed/parallel_state.py | 101 +++++- vllm_ascend/eplb/adaptor/vllm_adaptor.py | 5 + .../eplb/core/eplb_device_transfer_loader.py | 26 +- vllm_ascend/eplb/core/eplb_worker.py | 67 +++- .../eplb/core/policy/policy_abstract.py | 3 + .../eplb/core/policy/policy_default_eplb.py | 107 +++++- vllm_ascend/eplb/eplb_updator.py | 45 ++- vllm_ascend/ops/fused_moe/fused_moe.py | 4 +- vllm_ascend/ops/fused_moe/token_dispatcher.py | 4 +- vllm_ascend/platform.py | 75 ++++ .../quantization/methods/w8a8_dynamic.py | 18 +- vllm_ascend/worker/block_table.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 22 +- vllm_ascend/worker/worker.py | 13 +- 21 files changed, 1003 insertions(+), 87 deletions(-) create mode 100644 vllm_ascend/distributed/elastic_ep/__init__.py create mode 100644 vllm_ascend/distributed/elastic_ep/elastic_execute.py create mode 100644 vllm_ascend/distributed/elastic_ep/standby_state.py diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 1ceb3434616..6b5ad1e51a4 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +import weakref from collections.abc import Callable from contextlib import ExitStack from dataclasses import dataclass -from typing import Any +from typing import Any, ClassVar from unittest.mock import patch import torch @@ -60,6 +61,16 @@ class ACLGraphWrapper: guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". """ + all_instances: ClassVar[weakref.WeakSet["ACLGraphWrapper"]] = weakref.WeakSet() + graph_pool: ClassVar[tuple[int, int]] = current_platform.get_global_graph_pool() + + @classmethod + def clear_all_graphs(cls) -> None: + """Clear all graphs from all ACLGraphWrapper instances.""" + for instance in list(cls.all_instances): + instance.clear_graphs() + cls.graph_pool = (cls.graph_pool[0], cls.graph_pool[1] + 1) + def __init__( self, runnable: Callable, @@ -79,7 +90,6 @@ def __init__( # assert runtime_mode is not NONE(no aclgraph), otherwise, we don't # need to initialize a ACLGraphWrapper. assert self.runtime_mode != CUDAGraphMode.NONE - self.graph_pool = current_platform.get_global_graph_pool() if cudagraph_options is None: cudagraph_options = CUDAGraphOptions() @@ -102,6 +112,13 @@ def unwrap(self) -> Callable: # in case we need to access the original runnable. 
return self.runnable + def clear_graphs(self) -> None: + for batch_descriptor in self.concrete_aclgraph_entries: + entry = self.concrete_aclgraph_entries[batch_descriptor] + entry.aclgraph.reset() + del entry.aclgraph, entry.batch_descriptor, entry.output, entry.input_addresses, entry + self.concrete_aclgraph_entries.clear() + def __call__(self, *args, **kwargs): forward_context = get_forward_context() batch_descriptor = forward_context.batch_descriptor @@ -149,7 +166,7 @@ def __call__(self, *args, **kwargs): # mind-exploding: carefully manage the reference and memory. forward_context.capturing = True - with torch.npu.graph(aclgraph, pool=self.graph_pool): + with torch.npu.graph(aclgraph, pool=ACLGraphWrapper.graph_pool): # `output` is managed by pytorch's aclgraph pool output = self.runnable(*args, **kwargs) if self.aclgraph_options.weak_ref_output: diff --git a/vllm_ascend/distributed/device_communicators/npu_communicator.py b/vllm_ascend/distributed/device_communicators/npu_communicator.py index 6950c87af61..090a256caf5 100644 --- a/vllm_ascend/distributed/device_communicators/npu_communicator.py +++ b/vllm_ascend/distributed/device_communicators/npu_communicator.py @@ -18,6 +18,7 @@ import torch import torch.distributed as dist from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase +from vllm.distributed.utils import StatelessProcessGroup class NPUCommunicator(DeviceCommunicatorBase): @@ -27,12 +28,30 @@ def __init__( device: torch.device | None = None, device_group: dist.ProcessGroup | None = None, unique_name: str = "", + global_ranks: list[int] | None = None, + global_world_size: int | None = None, + tcp_store_group: StatelessProcessGroup | None = None, ): - super().__init__(cpu_group, device, device_group, unique_name) + super().__init__( + cpu_group, + device, + device_group, + unique_name, + global_ranks, + global_world_size, + ) # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator # init device according to rank self.device = torch.npu.current_device() + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator + + self.pyhccl_comm: PyHcclCommunicator | None = None + if self.world_size > 1: + self.pyhccl_comm = PyHcclCommunicator( + group=self.cpu_group if tcp_store_group is None else tcp_store_group, device=self.device + ) + def all_to_all( self, input_: torch.Tensor, @@ -62,3 +81,15 @@ def all_to_all( dist.all_to_all(output_list, input_list, group=self.device_group) output_tensor = torch.cat(output_list, dim=gather_dim).contiguous() return output_tensor + + def destroy(self): + if self.pyhccl_comm is not None: + self.pyhccl_comm.destroy() + self.pyhccl_comm = None + + def batch_isend_irecv(self, p2p_ops: list): + pyhccl_comm = self.pyhccl_comm + if pyhccl_comm is not None and not pyhccl_comm.disabled: + pyhccl_comm.batch_isend_irecv(p2p_ops) + else: + raise ValueError("No PyHccl communicator found") diff --git a/vllm_ascend/distributed/device_communicators/pyhccl.py b/vllm_ascend/distributed/device_communicators/pyhccl.py index 220c48f304c..9c0c797e9b2 100644 --- a/vllm_ascend/distributed/device_communicators/pyhccl.py +++ b/vllm_ascend/distributed/device_communicators/pyhccl.py @@ -127,6 +127,13 @@ def __init__( stream.synchronize() del data + def destroy(self): + if self.available and not self.disabled: + with torch.accelerator.device_index(self.device.index): + self.hccl.hcclCommDestroy(self.comm) + self.available = False + self.disabled = True + def all_reduce(self, 
in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor: if self.disabled: return None @@ -152,6 +159,40 @@ def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, strea ) return out_tensor + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return None + assert tensor.device == self.device, ( + f"this hccl communicator is created to work on {self.device}, but the tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.hccl.hcclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + hcclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + aclrtStream_t(stream.npu_stream), + ) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return None + assert tensor.device == self.device, ( + f"this hccl communicator is created to work on {self.device}, but the tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.hccl.hcclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + hcclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + aclrtStream_t(stream.npu_stream), + ) + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return @@ -172,3 +213,14 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): self.comm, aclrtStream_t(stream.npu_stream), ) + + def batch_isend_irecv(self, p2p_ops: list, stream=None): + if self.disabled: + return + if stream is None: + stream = current_stream() + for op in p2p_ops: + if op.op is torch.distributed.isend: + self.send(op.tensor, op.group_peer, stream) + elif op.op is torch.distributed.irecv: + self.recv(op.tensor, op.group_peer, stream) diff --git a/vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py b/vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py index e18992cf9ba..042713acd0b 100644 --- a/vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py +++ b/vllm_ascend/distributed/device_communicators/pyhccl_wrapper.py @@ -146,6 +146,38 @@ class HCCLLibrary: aclrtStream_t, ], ), + # HcclResult HcclSend( + # void *buf, uint64_t count, + # HcclDataType dataType, uint32_t root, + # HcclComm comm, aclrtStream steam); + Function( + "HcclSend", + hcclResult_t, + [ + buffer_type, + ctypes.c_size_t, + hcclDataType_t, + ctypes.c_int, + hcclComm_t, + aclrtStream_t, + ], + ), + # HcclResult HcclRecv( + # void *buf, uint64_t count, + # HcclDataType dataType, uint32_t root, + # HcclComm comm, aclrtStream steam); + Function( + "HcclRecv", + hcclResult_t, + [ + buffer_type, + ctypes.c_size_t, + hcclDataType_t, + ctypes.c_int, + hcclComm_t, + aclrtStream_t, + ], + ), # HcclResult HcclBroadcast( # void *buf, uint64_t count, # HcclDataType dataType, uint32_t root, @@ -243,6 +275,28 @@ def hcclAllReduce( # by ctypes automatically self.HCCL_CHECK(self._funcs["HcclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream)) + def hcclSend( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: hcclComm_t, + stream: aclrtStream_t, + ) -> None: + self.HCCL_CHECK(self._funcs["HcclSend"](sendbuff, count, datatype, dest, comm, stream)) + + def hcclRecv( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: hcclComm_t, + stream: aclrtStream_t, + ) -> None: + self.HCCL_CHECK(self._funcs["HcclRecv"](sendbuff, count, datatype, dest, comm, stream)) + def hcclBroadcast( self, buf: buffer_type, count: int, datatype: int, root: int, comm: hcclComm_t, 
stream: aclrtStream_t ) -> None: diff --git a/vllm_ascend/distributed/elastic_ep/__init__.py b/vllm_ascend/distributed/elastic_ep/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/distributed/elastic_ep/elastic_execute.py b/vllm_ascend/distributed/elastic_ep/elastic_execute.py new file mode 100644 index 00000000000..6aef9e1d1dc --- /dev/null +++ b/vllm_ascend/distributed/elastic_ep/elastic_execute.py @@ -0,0 +1,342 @@ +import copy +import gc + +import numpy as np +import torch +import torch_npu +from vllm.compilation.counter import compilation_counter +from vllm.compilation.wrapper import reset_compile_wrapper +from vllm.config import ( + CompilationMode, + set_current_vllm_config, +) +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_pcp_group, + get_tp_group, +) +from vllm.distributed.elastic_ep.elastic_execute import ElasticEPScalingExecutor +from vllm.distributed.elastic_ep.standby_state import ( + create_standby_groups, + get_standby_dp_group, + pop_standby_groups, +) +from vllm.distributed.parallel_state import _replace_active_groups +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator +from vllm.logger import logger +from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper +from vllm.v1.worker.workspace import lock_workspace, unlock_workspace + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.compilation.acl_graph import ACLGraphWrapper +from vllm_ascend.distributed.elastic_ep.standby_state import ( + create_ascend_standby_groups, + pop_ascend_standby_groups, +) +from vllm_ascend.distributed.parallel_state import ( + _replace_ascend_active_groups, + get_dynamic_eplb_group, + get_mc2_group, +) +from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method +from vllm_ascend.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod + + +def broadcast_expert_mapping( + expert_maps: torch.Tensor | None, + group: StatelessGroupCoordinator, + src_rank: int = 0, +): + if group.rank_in_group == src_rank: + assert expert_maps is not None + shape_tensor = torch.tensor(list(expert_maps.shape), dtype=torch.int64, device="cpu") + else: + shape_tensor = torch.empty(3, dtype=torch.int64, device="cpu") + + shape_tensor = group.tcp_store_group.broadcast(shape_tensor, src_rank) + + if group.rank_in_group != src_rank: + expert_maps = torch.empty( + tuple(shape_tensor.tolist()), + dtype=torch.int64, + device="cpu", + ) + + assert expert_maps is not None + expert_maps = group.tcp_store_group.broadcast(expert_maps, src_rank) + + return expert_maps + + +class AscendElasticEPScalingExecutor(ElasticEPScalingExecutor): + def __init__(self, worker): + super().__init__(worker) + + def load_model(self) -> None: + ( + expert_maps, + num_local_experts, + num_logical_experts, + ) = self.worker.elastic_ep_executor.receive_expert_mapping() + dp_size = self.worker.parallel_config.data_parallel_size + tp_size = self.worker.parallel_config.tensor_parallel_size + pcp_size = self.worker.parallel_config.prefill_context_parallel_size + ep_size = dp_size * tp_size * pcp_size + get_ascend_config().eplb_config.num_redundant_experts = ep_size * num_local_experts - num_logical_experts + if get_ascend_config().eplb_config.dynamic_eplb: + self.worker.model_runner.shared_dict["expert_maps"] = expert_maps + 
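+            # dim 1 of the received expert map is the EP size it was built for; keep it as
+            # "old_ep_size" so the EPLB worker can tell how many ranks previously held experts.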
self.worker.model_runner.shared_dict["old_ep_size"] = expert_maps.shape[1] + self.worker.load_model(load_dummy_weights=True) + + def create_standby_groups(self, reconfig_request: ReconfigureDistributedRequest) -> None: + self.reconfig_request = reconfig_request + new_dp_size = reconfig_request.new_data_parallel_size + world_size = self.worker.vllm_config.parallel_config.world_size + new_world_size_across_dp = world_size * new_dp_size + updated_config = copy.copy(self.worker.vllm_config) + updated_config.parallel_config = copy.deepcopy(self.worker.vllm_config.parallel_config) + updated_config.parallel_config.data_parallel_size = new_dp_size + with set_current_vllm_config(updated_config): + create_standby_groups( + new_dp_size=new_dp_size, + new_world_size_across_dp=new_world_size_across_dp, + master_ip=reconfig_request.new_data_parallel_master_ip, + coord_store_port=reconfig_request.coord_store_port, + enable_eplb=updated_config.parallel_config.enable_eplb, + ) + create_ascend_standby_groups( + new_dp_size=new_dp_size, + new_world_size_across_dp=new_world_size_across_dp, + master_ip=reconfig_request.new_data_parallel_master_ip, + coord_store_port=reconfig_request.coord_store_port, + ) + + def broadcast_expert_mapping(self): + standby_dp_group = get_standby_dp_group() + assert standby_dp_group is not None + expert_maps = self.worker.model_runner.shared_dict["expert_maps"] + broadcast_expert_mapping( + expert_maps=expert_maps, + group=standby_dp_group, + src_rank=0, + ) + + def _release_acl_graphs(self) -> None: + if isinstance(self.worker.model_runner.model, UBatchWrapper): + raise RuntimeError("DBO is not yet supported in elastic EP") + + ACLGraphWrapper.clear_all_graphs() + + torch.compiler.reset() + with set_current_vllm_config(self.worker.vllm_config): + reset_compile_wrapper(self.worker.model_runner.get_model()) + + gc.collect() + torch.npu.synchronize() + torch.npu.empty_cache() + + def switch_and_remove(self) -> None: + self._release_acl_graphs() + _replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None) + _replace_ascend_active_groups(mc2=None, dynamic_eplb=None, fc3_quant_x=None) + + def switch_and_prepare(self) -> None: + old_ep_size = get_ep_group().world_size + self.worker.model_runner.shared_dict["old_ep_size"] = old_ep_size + + self._release_acl_graphs() + _replace_active_groups(**pop_standby_groups()) + _replace_ascend_active_groups(**pop_ascend_standby_groups()) + + parallel_config = self.worker.vllm_config.parallel_config + reconfig_request = self.reconfig_request + assert reconfig_request is not None + new_dp_size = reconfig_request.new_data_parallel_size + new_ep_size = get_ep_group().world_size + + parallel_config.data_parallel_size = new_dp_size + + if reconfig_request.new_data_parallel_rank != ReconfigureRankType.KEEP_CURRENT_RANK: + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if reconfig_request.new_data_parallel_rank_local != ReconfigureRankType.KEEP_CURRENT_RANK: + parallel_config.data_parallel_rank_local = reconfig_request.new_data_parallel_rank_local + parallel_config.data_parallel_master_ip = reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = reconfig_request.new_data_parallel_master_port + self.worker.model_runner.dp_size = new_dp_size + self.worker.model_runner.eplb_updator.comm_group = get_dynamic_eplb_group() + self.worker.model_runner.eplb_updator.world_size = get_dynamic_eplb_group().world_size + self.worker.model_runner.eplb_updator.cur_iterations = 0 + 
self.worker.model_runner.eplb_loader.comm_group = get_dynamic_eplb_group() + + # Reconfigure MoE modules with new EP size + moe_modules = [ + module + for module in self.worker.model_runner.model.modules() + if (module.__class__.__name__ == "AscendFusedMoE" or module.__class__.__name__ == "AscendSharedFusedMoE") + ] + num_local_experts = moe_modules[0].moe_config.num_local_experts + assert all(module.moe_config.num_local_experts == num_local_experts for module in moe_modules), ( + "All MoE modules must have the same number of experts" + ) + for module in moe_modules: + # module.local_num_experts = module.w2_weight.shape[0] + num_logical_experts = self.worker.model_runner.shared_dict["expert_maps"].shape[-1] + module.global_redundant_expert_num = module.local_num_experts * new_ep_size - num_logical_experts + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + tp_size = get_tp_group().world_size + is_sequence_parallel = parallel_config.use_sequence_parallel_moe + sp_size = tp_size if is_sequence_parallel else 1 + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=tp_size, + pcp_size_=get_pcp_group().world_size, + dp_size_=get_dp_group().world_size, + sp_size_=sp_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + + module.moe_config.tp_group = get_tp_group() + module.moe_config.dp_group = get_dp_group() + module.moe_config.ep_group = get_ep_group() + module.moe_config.mc2_group = get_mc2_group() + + with set_current_vllm_config(self.worker.vllm_config): + if hasattr(module.quant_method, "quant_method") and isinstance( + module.quant_method.quant_method, AscendW8A8DynamicFusedMoEMethod + ): + module.quant_method.quant_method = AscendW8A8DynamicFusedMoEMethod() + setup_moe_comm_method(module.moe_config) + + if self.worker.vllm_config.compilation_config.mode == CompilationMode.STOCK_TORCH_COMPILE: + # NOTE(yongji): when using stock torch.compile, + # torch.compile is triggered during GPUModelRunner's load_model() + # TODO(yongji):check do we need to re-trigger torch.compile here? + # any changes to the tensor shapes in execution should already + # be handled internally by torch.compile. 
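+            # Re-initialize the compile backend and re-run a full-graph torch.compile so the
+            # model is recompiled against the reconfigured parallel state.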
+ backend = self.worker.vllm_config.compilation_config.init_backend(self.worker.vllm_config) + compilation_counter.stock_torch_compile_count += 1 + self.worker.model_runner.model.compile(fullgraph=True, backend=backend) + + multi_block_table = self.worker.model_runner.input_batch.block_table + saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = [] + for bt in multi_block_table.block_tables: + saved_block_tables.append((bt.block_table.gpu.clone(), bt.block_table.cpu.clone())) + multi_block_table.clear() + + unlock_workspace() + self.worker.compile_or_warm_up_model() + lock_workspace() + + for bt, (saved_gpu, saved_cpu) in zip(multi_block_table.block_tables, saved_block_tables): + bt.block_table.gpu.copy_(saved_gpu) + bt.block_table.cpu.copy_(saved_cpu) + + def _perform_eplb_reshuffle(self): + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding...") + eplb_loader = self.worker.model_runner.eplb_loader + eplb_adaptor = self.worker.model_runner.eplb_adaptor + eplb_updator = self.worker.model_runner.eplb_updator + + eplb_updator.compute_and_set_moe_load() + # Wake up the EPLB worker to retrieve expert placement update information + eplb_updator.wakeup_eplb_worker() + # Retrieve the blocking update queue containing expert resharding information + eplb_updator.update_info_all = eplb_updator.eplb_process.block_update_q.get() + # Process each layer's expert redistribution information + while eplb_updator.update_info_all: + (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = ( + eplb_updator.update_info_all.pop(0) + ) + # Convert logical to physical expert mapping to tensor for this rank + log2phy_map_this_rank = torch.from_numpy(np.array(log2phy_map)) + eplb_loader.set_log2phy_map(log2phy_map_this_rank) + # Convert updated expert mapping to tensor for this rank + updated_expert_map_this_rank = torch.from_numpy(np.array(updated_expert_map)) + # Get global expert map for this layer from shared dictionary + # updated_global_expert_map_this_rank = self.worker.model_runner.shared_dict["expert_maps"][layer_id] + # Generate device-to-device transfer tasks for expert weights + eplb_loader.generate_expert_d2d_transfer_task( + expert_send_info, + expert_recv_info, + updated_expert_map_this_rank, + layer_id + eplb_adaptor.num_dense_layers, + ) + # Execute asynchronous expert weight transfer + reqs = [] + eplb_loader.asyn_expert_weight_transfer(reqs) + # Update expert mapping and apply transferred weights + eplb_loader.update_expert_map_and_weight(reqs) + + # Clear all MoE load statistics after resharding + eplb_adaptor.model.clear_all_moe_loads() + # Reset iteration counter for the updator + eplb_updator.cur_iterations = 0 + # Synchronize NPU to ensure all transfers are complete + torch_npu.npu.synchronize() + + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed") + + self.worker.model_runner.shared_dict["scale"] = False + self.worker.model_runner.shared_dict["old_ep_size"] = None + self.worker.model_runner.shared_dict["new_ep_size"] = None + + def perform_eplb_reshuffle(self) -> None: + new_ep_size = get_ep_group().world_size + self.worker.model_runner.shared_dict["scale"] = True + self.worker.model_runner.shared_dict["new_ep_size"] = new_ep_size + + self._perform_eplb_reshuffle() + + def perform_scale_down_eplb_reshuffle(self, new_dp_size: int) -> None: + parallel_config = self.worker.vllm_config.parallel_config + tp_size = parallel_config.tensor_parallel_size + old_ep_size = 
parallel_config.data_parallel_size * tp_size + new_ep_size = new_dp_size * tp_size + + self.worker.model_runner.shared_dict["scale"] = True + self.worker.model_runner.shared_dict["old_ep_size"] = old_ep_size + self.worker.model_runner.shared_dict["new_ep_size"] = new_ep_size + + self._perform_eplb_reshuffle() + + def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]: + dp_group = get_dp_group() + assert isinstance(dp_group, StatelessGroupCoordinator) + expert_maps = broadcast_expert_mapping( + expert_maps=None, + group=dp_group, + src_rank=0, + ) + num_local_experts = (expert_maps[0, 0] != -1).sum().item() + num_logical_experts = expert_maps.shape[-1] + + return expert_maps, num_local_experts, num_logical_experts + + def prepare_new_worker(self) -> None: + moe_modules = [ + module + for module in self.worker.model_runner.model.modules() + if (module.__class__.__name__ == "AscendFusedMoE" or module.__class__.__name__ == "AscendSharedFusedMoE") + ] + for module in moe_modules: + with set_current_vllm_config(self.worker.vllm_config): + if hasattr(module.quant_method, "quant_method") and isinstance( + module.quant_method.quant_method, AscendW8A8DynamicFusedMoEMethod + ): + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = get_mc2_group().rank_in_group + backend = device_group._get_backend(torch.device("npu")) + module.quant_method.quant_method.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank + ) + except AttributeError: + module.quant_method.quant_method.moe_all_to_all_group_name = "" + setup_moe_comm_method(module.moe_config) diff --git a/vllm_ascend/distributed/elastic_ep/standby_state.py b/vllm_ascend/distributed/elastic_ep/standby_state.py new file mode 100644 index 00000000000..f8ffff72708 --- /dev/null +++ b/vllm_ascend/distributed/elastic_ep/standby_state.py @@ -0,0 +1,94 @@ +import torch +from vllm.distributed.parallel_state import ( + _init_stateless_group, + get_pp_group, + get_tp_group, + get_world_group, +) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator +from vllm.distributed.utils import get_cached_tcp_store_client + +from vllm_ascend.ascend_config import get_ascend_config + +_STANDBY_MC2: StatelessGroupCoordinator | None = None +_STANDBY_DYNAMIC_EPLB: StatelessGroupCoordinator | None = None +_STANDBY_FC3_QUANT_X: StatelessGroupCoordinator | None = None + + +def get_standby_mc2_group() -> StatelessGroupCoordinator | None: + return _STANDBY_MC2 + + +def get_standby_dynamic_eplb_group() -> StatelessGroupCoordinator | None: + return _STANDBY_DYNAMIC_EPLB + + +def get_standby_fc3_quant_x_group() -> StatelessGroupCoordinator | None: + return _STANDBY_FC3_QUANT_X + + +def create_ascend_standby_groups( + new_dp_size: int, + new_world_size_across_dp: int, + master_ip: str, + coord_store_port: int, + backend: str | None = None, +) -> None: + global _STANDBY_MC2, _STANDBY_DYNAMIC_EPLB, _STANDBY_FC3_QUANT_X + + assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size + world_group = get_world_group() + assert isinstance(world_group, StatelessGroupCoordinator) + backend = backend or world_group.backend + + coord_store = get_cached_tcp_store_client(master_ip, coord_store_port) + + tp_size = get_tp_group().world_size + pp_size = get_pp_group().world_size + + all_ranks = torch.arange(new_world_size_across_dp).reshape(-1, new_dp_size * pp_size * tp_size) + group_ranks = all_ranks.unbind(0) + standby_ep_ranks = [x.tolist() for x in group_ranks] 
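+    # Pre-create the standby MC2 / dynamic-EPLB / FC3 groups over the new EP ranks; they are
+    # swapped in later when pop_ascend_standby_groups() feeds _replace_ascend_active_groups().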
+ + _STANDBY_MC2 = _init_stateless_group( + standby_ep_ranks, + "mc2", + master_ip, + backend, + coord_store=coord_store, + use_device_communicator=False, + ) + + if get_ascend_config().eplb_config.dynamic_eplb: + _STANDBY_DYNAMIC_EPLB = _init_stateless_group( + standby_ep_ranks, + "dynamic_eplb", + master_ip, + backend, + coord_store=coord_store, + use_device_communicator=False, + ) + + if get_ascend_config().multistream_overlap_gate: + _STANDBY_FC3_QUANT_X = _init_stateless_group( + standby_ep_ranks, + "fc3_quant_x", + master_ip, + backend, + coord_store=coord_store, + use_device_communicator=False, + ) + + +def pop_ascend_standby_groups() -> dict: + """Return all standby groups and clear the standby state.""" + global _STANDBY_MC2, _STANDBY_DYNAMIC_EPLB, _STANDBY_FC3_QUANT_X + result = dict( + mc2=_STANDBY_MC2, + dynamic_eplb=_STANDBY_DYNAMIC_EPLB, + fc3_quant_x=_STANDBY_FC3_QUANT_X, + ) + _STANDBY_MC2 = None + _STANDBY_DYNAMIC_EPLB = None + _STANDBY_FC3_QUANT_X = None + return result diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 5ed7d3dd934..f502ebb8ecd 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -1,6 +1,13 @@ import torch from vllm.config import ParallelConfig, get_current_vllm_config -from vllm.distributed.parallel_state import GroupCoordinator, get_tp_group, get_world_group, init_model_parallel_group +from vllm.distributed import get_cached_tcp_store_client +from vllm.distributed.parallel_state import ( + GroupCoordinator, + _init_stateless_group, + get_tp_group, + get_world_group, + init_model_parallel_group, +) from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.utils import enable_dsa_cp_with_layer_shard, flashcomm2_enable @@ -33,12 +40,26 @@ def init_ascend_model_parallel( if model_parallel_initialized(): return assert torch.distributed.is_initialized() + enable_elastic_ep = parallel_config.enable_elastic_ep world_size = torch.distributed.get_world_size() backend = torch.distributed.get_backend(get_world_group().device_group) global_tp_size = parallel_config.tensor_parallel_size global_dp_size = parallel_config.data_parallel_size global_pp_size = parallel_config.pipeline_parallel_size global_pcp_size = parallel_config.prefill_context_parallel_size + if enable_elastic_ep: + coord_store = get_cached_tcp_store_client( + parallel_config.data_parallel_master_ip, + parallel_config._coord_store_port, + ) + # Use stateless world group for global information + world_size = get_world_group().world_size + tp_pp_pcp_size = global_tp_size * global_pp_size * global_pcp_size + local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(global_pp_size, global_pcp_size, global_tp_size) + backend = "hccl" + else: + world_size = torch.distributed.get_world_size() + backend = torch.distributed.get_backend(get_world_group().device_group) # The layout of all ranks: ExternalDP * EP # ExternalDP is the data parallel group that is not part of the model, @@ -61,11 +82,17 @@ def init_ascend_model_parallel( num_head_replica = get_ascend_config().num_head_replica remote_tp_size = global_tp_size // pd_tp_ratio if num_head_replica <= 1: - group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0) + if enable_elastic_ep: + group_ranks = local_all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0) + else: + group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0) else: - group_ranks = all_ranks.clone().view( - global_dp_size 
* global_pp_size * global_pcp_size, -1, num_head_replica - ) # [DP_size, num_head, num_head_replica] + if enable_elastic_ep: + group_ranks = local_all_ranks.clone().view(global_pp_size * global_pcp_size, -1, num_head_replica) + else: + group_ranks = all_ranks.clone().view( + global_dp_size * global_pp_size * global_pcp_size, -1, num_head_replica + ) # [DP_size, num_head, num_head_replica] group_ranks = group_ranks.permute(0, 2, 1) group_ranks = group_ranks.reshape(-1, group_ranks.size(-1)) # [DP_size * num_head_replica, num_head] alltoall_group_size = group_ranks.size(-1) // remote_tp_size @@ -93,19 +120,49 @@ def init_ascend_model_parallel( group_ranks = [x.tolist() for x in group_ranks] global _MC2 - _MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2") + if enable_elastic_ep: + _MC2 = _init_stateless_group( + group_ranks, + "mc2", + parallel_config.data_parallel_master_ip, + backend, + coord_store=coord_store, + use_device_communicator=False, + ) + else: + _MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2") if get_ascend_config().eplb_config.dynamic_eplb: global _DYNAMIC_EPLB - _DYNAMIC_EPLB = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb" - ) + if enable_elastic_ep: + _DYNAMIC_EPLB = _init_stateless_group( + group_ranks, + "dynamic_eplb", + parallel_config.data_parallel_master_ip, + backend, + coord_store=coord_store, + use_device_communicator=False, + ) + else: + _DYNAMIC_EPLB = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb" + ) if get_ascend_config().multistream_overlap_gate: global _FC3_QUANT_X - _FC3_QUANT_X = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x" - ) + if enable_elastic_ep: + _FC3_QUANT_X = _init_stateless_group( + group_ranks, + "fc3_quant_x", + parallel_config.data_parallel_master_ip, + backend, + coord_store=coord_store, + use_device_communicator=False, + ) + else: + _FC3_QUANT_X = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x" + ) # Initialize fine-grained TP process groups on Ascend for four components: # 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`) @@ -226,6 +283,26 @@ def create_shard_weight_group(module_tp_group_ranks: None) -> GroupCoordinator: _SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks) +def _replace_ascend_active_groups( + *, + mc2: GroupCoordinator | None, + dynamic_eplb: GroupCoordinator | None, + fc3_quant_x: GroupCoordinator | None, +) -> None: + """Destroy the current DP/EP/WORLD/EPLB groups and replace them. + + Destruction is collective — all ranks in the old groups must call this + function together. Pass all-``None`` to tear down without replacement. 
+ """ + global _MC2, _DYNAMIC_EPLB, _FC3_QUANT_X + for group in (_MC2, _DYNAMIC_EPLB, _FC3_QUANT_X): + if group is not None: + group.destroy() + _MC2 = mc2 + _DYNAMIC_EPLB = dynamic_eplb + _FC3_QUANT_X = fc3_quant_x + + def model_parallel_initialized(): return _MC2 is not None diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 7cd71f89df1..4749e43d54e 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -43,6 +43,8 @@ def __init__(self, model, **args): num_buffer_tensor = self.num_local_experts self.buffer_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)] + if self.model.quant_config is None and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: + self.temp_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)] self.init_buffer_tensor(num_buffer_tensor) self.log2phy_map_per_layer = dict() @@ -58,6 +60,9 @@ def init_buffer_tensor(self, num_buffer_tensor): expert_tensor = self.param_dict[complete_name][0] buffer_tensor = torch.empty_like(expert_tensor) self.buffer_tensor_list[buffer_id].append(buffer_tensor) + if self.model.quant_config is None and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: + temp_tensor = torch.empty_like(expert_tensor) + self.temp_tensor_list[buffer_id].append(temp_tensor) def init_expert_param_per_layer(self): self.param_dict = dict() diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py index 79321aca2d6..69a127457ae 100644 --- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py +++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py @@ -52,13 +52,18 @@ def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info, self.layer_id = layer_id self.comm_op_list = [] + has_temp = set() for send_info in expert_send_info: dst_rank, global_expert_id_to_send = send_info local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[layer_id][global_expert_id_to_send].item() - for src_tensor in self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]: - self.comm_op_list.append( - dist.P2POp(dist.isend, src_tensor, dst_rank, group=self.comm_group.device_group) - ) + op_info = {"peer_rank": dst_rank, "tensors": [], "expert_id": global_expert_id_to_send, "op": "send"} + for index, src_tensor in enumerate(self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]): + if hasattr(self.eplb_adaptor, "temp_tensor_list"): + if local_expert_id not in has_temp: + self.eplb_adaptor.temp_tensor_list[layer_id][index].copy_(src_tensor) + op_info["tensors"].append(self.eplb_adaptor.temp_tensor_list[layer_id][index]) + has_temp.add(local_expert_id) + self.comm_op_list.append(op_info) for buffer_tensor_id, recv_info in enumerate(expert_recv_info): recv_rank, global_expert_id_to_recv = recv_info @@ -81,8 +86,17 @@ def asyn_expert_weight_transfer(self, reqs): # set asynchronous stream for d2d expert weight transfer if self.comm_op_list: - ret_list = dist.batch_isend_irecv(self.comm_op_list) - reqs.extend(ret_list) + for op_info in self.comm_op_list: + peer_rank = op_info["peer_rank"] + tensors = op_info["tensors"] + expert_id = op_info["expert_id"] + op = op_info["op"] + for i, tensor in enumerate(tensors): + if op == "send": + worker = self.comm_group.device_group.send([tensor], peer_rank, tag=(expert_id + 1) * (i + 1)) + else: + worker = self.comm_group.device_group.recv([tensor], peer_rank, tag=(expert_id + 1) * (i + 1)) + reqs.append(worker) self.state = 
ExpertWeightUpdateState.TRANSFERRING diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index b9a62a0d1df..a10fcb9d415 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -19,7 +19,7 @@ import numpy as np import torch -import torch.distributed as dist +from vllm.distributed.parallel_state import get_ep_group from vllm.logger import logger from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map @@ -33,7 +33,7 @@ def __init__(self, shared_dict, policy_type, enable_d2d: bool = True): self.shared_dict = shared_dict self.old_expert_maps = None self.enable_d2d = enable_d2d - self.rank_id = dist.get_rank() + self.rank_id = get_ep_group().rank_in_group self.multi_stage = policy_type == 3 def do_update(self): @@ -50,6 +50,7 @@ def do_update(self): self.old_expert_maps = self.get_init_expert_maps() if self.old_expert_maps is not None: self.num_local_experts = self.old_expert_maps.max() + 1 + self.num_experts = self.old_expert_maps.shape[-1] else: raise ValueError("Failed to get expert_maps from shared_dict.") @@ -60,6 +61,16 @@ def do_update(self): # Get the updated expert table based on the workload information old_placement = self.global2local(self.old_expert_maps, self.num_local_experts) + scale = self.shared_dict.get("scale", False) + if scale: + old_ep_size = self.shared_dict["old_ep_size"] + new_ep_size = self.shared_dict["new_ep_size"] + assert old_ep_size != new_ep_size + self.policy.set_new_ep_size(new_ep_size) + if load_info.shape[1] > old_ep_size: + load_info = load_info[:, :old_ep_size] + if self.old_expert_maps.shape[1] > old_ep_size: + self.old_expert_maps = self.old_expert_maps[:, :old_ep_size] _, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement) if self.rank_id == 0: @@ -80,6 +91,21 @@ def do_update(self): self.check_expert_placement(old_placement, new_placement) new_expert_maps = self.local2global(new_placement) self.update_expert_map(new_expert_maps) + new_expert_maps_clone = new_expert_maps.clone() + + if scale: + shape = list(new_expert_maps_clone.shape) + shape[1] = abs(old_ep_size - new_ep_size) + if old_ep_size > new_ep_size: + # when scale down, ensure that the shutdown ranks do not own any experts + # by setting their expert_map to all -1 + shutdown_rank_expert_maps = torch.full(shape, -1, dtype=new_expert_maps.dtype) + new_expert_maps = torch.cat([new_expert_maps, shutdown_rank_expert_maps], dim=1) + else: + # when scale up, ensure that new ranks do not own any experts + # by setting their expert_map to all -1 + new_rank_expert_maps = torch.full(shape, -1, dtype=new_expert_maps.dtype) + self.old_expert_maps = torch.cat([self.old_expert_maps, new_rank_expert_maps], dim=1) update_info = self.compose_expert_update_info_greedy(new_expert_maps, self.old_expert_maps) self.old_expert_maps = new_expert_maps @@ -91,19 +117,21 @@ def do_update(self): def check_expert_placement(self, old_placement, new_placement): num_layers = old_placement.shape[0] - num_ranks = old_placement.shape[1] + num_ranks = max(old_placement.shape[1], new_placement.shape[1]) for layer_id in range(num_layers): # check if any logical expert is not placed on any rank - if torch.unique(new_placement[layer_id]).numel() < torch.unique(old_placement[layer_id]).numel(): + if torch.unique(new_placement[layer_id]).numel() < self.num_experts: logger.error(f"There exists expert not placed on any rank in layer {layer_id}") new_placement[layer_id] = old_placement[layer_id] continue for rank_id in 
range(num_ranks): - new_placement_check = new_placement[layer_id][rank_id] - old_placement_check = old_placement[layer_id][rank_id] + new_placement_check = new_placement[layer_id][rank_id] if new_placement.shape[1] > rank_id else None + old_placement_check = old_placement[layer_id][rank_id] if old_placement.shape[1] > rank_id else None + if new_placement_check is None: + break # check if same logical experts are placed on the same NPU if new_placement_check.numel() != torch.unique(new_placement_check).numel(): logger.error( @@ -114,14 +142,27 @@ def check_expert_placement(self, old_placement, new_placement): break # check if there is any experts movement inside one NPU - expert_not_move = torch.isin(new_placement_check, old_placement_check) - if not torch.equal(new_placement_check[expert_not_move], old_placement_check[expert_not_move]): - logger.error( - "There exists expert movement inside NPU; expert placement on " - f"layer {layer_id}, rank {rank_id} is invalid" + if old_placement_check is None: + continue + expert_not_move_mask = torch.isin(new_placement_check, old_placement_check) + if not expert_not_move_mask.any(): + continue + expert_not_move = new_placement_check[expert_not_move_mask] + old_indices = [] + for expert in expert_not_move: + old_idx = torch.where(old_placement_check == expert)[0] + old_indices.append(old_idx.item()) + new_indices = torch.where(expert_not_move_mask)[0].tolist() + if old_indices != new_indices: + logger.info( + "There exists expert movement inside NPU, expert placement on" + f"layer {layer_id}, rank {rank_id} is invalid, try to rearrange it!" ) - new_placement[layer_id] = old_placement[layer_id] - break + new_placement_this_rank = new_placement_check.clone() + new_placement_this_rank[old_indices] = expert_not_move + available_positions = list(set(range(len(new_placement_check))) - set(old_indices)) + new_placement_this_rank[available_positions] = new_placement_check[~expert_not_move_mask] + new_placement[layer_id][rank_id] = new_placement_this_rank # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases def compose_expert_update_info_greedy(self, updated_expert_maps, current_expert_maps): diff --git a/vllm_ascend/eplb/core/policy/policy_abstract.py b/vllm_ascend/eplb/core/policy/policy_abstract.py index ce2a764cb05..79099751b9e 100644 --- a/vllm_ascend/eplb/core/policy/policy_abstract.py +++ b/vllm_ascend/eplb/core/policy/policy_abstract.py @@ -15,6 +15,9 @@ class EplbPolicy: def __init__(self, config: DynamicConfig): self.config = config + def set_new_ep_size(self, new_ep_size): + pass + @abstractmethod def rebalance_experts(self, current_expert_table, expert_workload): """ diff --git a/vllm_ascend/eplb/core/policy/policy_default_eplb.py b/vllm_ascend/eplb/core/policy/policy_default_eplb.py index 5348f301b83..6d76848ab4c 100644 --- a/vllm_ascend/eplb/core/policy/policy_default_eplb.py +++ b/vllm_ascend/eplb/core/policy/policy_default_eplb.py @@ -27,6 +27,10 @@ class DynamicTable: class DefaultEplb(EplbPolicy): def __init__(self, config: DynamicConfig): super().__init__(config) + self._new_ep_size = None + + def set_new_ep_size(self, new_ep_size: int): + self._new_ep_size = new_ep_size @staticmethod def add_redundant(current_expert_table, expert_workload, num_original_expert): @@ -53,10 +57,13 @@ def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redu for i in range(num_redundancy_expert): sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1] 
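+            # Give the next redundant slot to the hottest expert that still has room for another
+            # replica (each expert is limited to card_num copies in total).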
weights = [origin_weights[idx] for idx in sorted_indices] - tmp_raw_weight = weights[0][1] * (len(route_expert_redundancy[weights[0][0]]) + 1) - route_expert_redundancy[weights[0][0]].append(route_expert_num + i) - avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[0][0]]) + 1) - weights[0] = (weights[0][0], avg_weight) + index = 0 + while (len(route_expert_redundancy[weights[index][0]])) == card_num - 1: + index += 1 + tmp_raw_weight = weights[index][1] * (len(route_expert_redundancy[weights[index][0]]) + 1) + route_expert_redundancy[weights[index][0]].append(route_expert_num + i) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[index][0]]) + 1) + weights[index] = (weights[index][0], avg_weight) origin_weights = weights # Step 2: Calculate the number of items per box @@ -83,6 +90,7 @@ def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redu box_weights[index] += cur_weight box_counts[index] += 1 index += 1 + index = index % card_num sorted_indices = np.argsort([t[1] for t in origin_weights], kind="stable")[::-1] origin_weights = [origin_weights[idx] for idx in sorted_indices] @@ -98,6 +106,17 @@ def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redu if min_box_index == -1 or box_weights[i] < box_weights[min_box_index]: min_box_index = i + if min_box_index == -1: + # Try to place in the last box first + if box_counts[-1] < items_per_box or (box_counts[-1] == items_per_box and remaining_items > 0): + min_box_index = -1 + else: + # Find any box with capacity + for i in range(card_num): + if box_counts[i] < items_per_box or (box_counts[i] == items_per_box and remaining_items > 0): + min_box_index = i + break + # Place the item (id) into the selected box boxes[min_box_index].append(item_id) boxes_weights[min_box_index].append(weight) @@ -108,7 +127,58 @@ def original_compute_balanced_pack_redundancy(origin_weights, card_num, num_redu if box_counts[min_box_index] == (items_per_box + 1) and remaining_items > 0: remaining_items -= 1 - # Step 5: Output each box's contents and total weight + # Step 5: Eliminate duplicate experts within the same NPU through redundancy + # reallocation. Replace duplicates with redundant copies from other + # experts based on minimal weight difference. 
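+        # For every duplicated expert id on a card, the extra copy is replaced by a replica of a
+        # different expert whose projected per-replica weight is closest to the duplicate's, and
+        # route_expert_redundancy / origin_weights are updated to reflect the swap.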
+ for i in range(card_num): + arr = np.asarray(boxes[i]) + unique, inv, cnt = np.unique(arr, return_inverse=True, return_counts=True) + mask = cnt > 1 + dup_vals = unique[mask] + dup_cnts = cnt[mask] + for item_id, counts in zip(dup_vals, dup_cnts): + for _ in range(counts - 1): + cur_position = boxes[i].index(item_id) + cur_weight = boxes_weights[i][cur_position] + sorted_indices = np.argsort( + [ + abs( + t[1] + * (len(route_expert_redundancy[t[0]]) + 1) + / (len(route_expert_redundancy[t[0]]) + 2) + - cur_weight + ) + for t in origin_weights + ], + kind="stable", + ) + weights = [origin_weights[idx] for idx in sorted_indices] + index = 0 + while index < len(weights): + if ( + len(route_expert_redundancy[weights[index][0]]) < card_num - 1 + and weights[index][0] != item_id + and weights[index][0] not in boxes[i] + ): + break + index += 1 + boxes[i][cur_position] = weights[index][0] + tmp_raw_weight = weights[index][1] * (len(route_expert_redundancy[weights[index][0]]) + 1) + route_expert_redundancy[weights[index][0]].append(0) + avg_weight = tmp_raw_weight / (len(route_expert_redundancy[weights[index][0]]) + 1) + boxes_weights[i][cur_position] = avg_weight + weights[index] = (weights[index][0], avg_weight) + tmp_raw_weight = cur_weight * (len(route_expert_redundancy[item_id]) + 1) + avg_weight = tmp_raw_weight / len(route_expert_redundancy[item_id]) + route_expert_redundancy[item_id].pop() + for index, (expert_id, expert_weight) in enumerate(weights): + if item_id == expert_id: + weights[index] = (expert_id, avg_weight) + origin_weights = weights + + box_weights = [sum(boxes_weights[i]) for i in range(card_num)] + + # Step 6: Output each box's contents and total weight result = [] for i in range(card_num): result.append( @@ -292,8 +362,12 @@ def rebalance_experts(self, current_expert_table, expert_workload): assert info.placement_table is not None row = cast(np.ndarray, info.placement_table[0]) expert_ids, counts = np.unique(row, return_counts=True) - num_redundancy_expert = self.get_redundant_num(num_npus, counts) num_original_expert = len(expert_ids) + if self._new_ep_size: + num_npus = self._new_ep_size + num_redundancy_expert = experts_per_npu * self._new_ep_size - num_original_expert + else: + num_redundancy_expert = self.get_redundant_num(num_npus, counts) layer_workloads = self.add_redundant(info.placement_table, info.workload_table, num_original_expert) max_heat_per_layer_before = self.calculate_max_heat_per_layer(info.workload_table, layer_num) npu_heat_all_origin = sum(max_heat_per_layer_before) @@ -311,11 +385,15 @@ def rebalance_experts(self, current_expert_table, expert_workload): if num_npus <= 0: raise ValueError("the number of NPUs must be greater than 0") - if num_npus < num_redundancy_expert: + if experts_per_npu > expert_num: raise ValueError( - "the number of NPUs " - f"{num_npus} must be greater than or equal to the number of redundant experts " - f"{num_redundancy_expert}" + f"the number of experts per NPU {experts_per_npu} can't be greater than expert_num {expert_num}" + ) + + if num_npus * experts_per_npu < num_original_expert: + raise ValueError( + f"num_npus {num_npus} * experts_per_npu {experts_per_npu} " + f"can't be less than num_original_expert {num_original_expert}" ) # Number of experts deployed on each card includes one redundant expert @@ -324,7 +402,7 @@ def rebalance_experts(self, current_expert_table, expert_workload): max_heat_per_layer_after = np.zeros([layer_num]) for layer in range(layer_num): # Get the expert IDs and their corresponding workloads 
for the current layer; - # workloads need to be normalized, and one redundant expert is added per card + # redundant experts will be created and distributed during the packing process weights = np.zeros((expert_num,), dtype="object") for expert_id, workload_weight in enumerate(layer_workloads[layer]): weights[expert_id] = (expert_id, workload_weight) @@ -337,10 +415,11 @@ def rebalance_experts(self, current_expert_table, expert_workload): global_deployment[layer] = layer_deployment max_heat_per_layer_after[layer] = max(result, key=lambda x: x["total_weight"])["total_weight"] - new_global_deployment = self.constraint_expert_local_exchange(current_expert_table, global_deployment) # Obtain the priority of each layer layer_changed_ratio = [] for layer_idx in range(layer_num): + if max_heat_per_layer_before[layer_idx] == 0: + max_heat_per_layer_before[layer_idx] = 1e-6 layer_changed_ratio.append(max_heat_per_layer_after[layer_idx] / max_heat_per_layer_before[layer_idx]) per_layer_priority = np.argsort(layer_changed_ratio) @@ -350,4 +429,6 @@ def rebalance_experts(self, current_expert_table, expert_workload): if npu_heat_all_after < 0.95 * npu_heat_all_origin: change = 1 - return change, per_layer_priority, np.array(new_global_deployment).tolist() + self._new_ep_size = None + + return change, per_layer_priority, np.array(global_deployment).tolist() diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index 285bd486e71..5bdc8cb8ff8 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -17,7 +17,6 @@ # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this updator. import numpy import torch -import torch.distributed as dist import vllm.envs as envs from vllm.logger import logger @@ -31,22 +30,22 @@ class EplbUpdator: def __init__(self, eplb_config, loader: D2DExpertWeightLoader, eplb_process: EplbProcess, process): self.eplb_config = eplb_config self.multi_stage = eplb_config.eplb_policy_type == 3 + self.comm_group = get_dynamic_eplb_group() self.init_eplb(self.eplb_config.expert_map_path, process) self.eplb_loader = loader self.eplb_process = eplb_process self.shared_dict = self.eplb_process.shared_dict - self.comm_group = get_dynamic_eplb_group() def set_adaptor(self, adaptor: VllmEplbAdaptor): self.adaptor = adaptor self.num_moe_layers = self.adaptor.num_moe_layers local_load = self.adaptor.get_rank_expert_workload() - self.world_size = dist.get_world_size() + self.world_size = self.comm_group.world_size self.device = local_load.device self.eplb_loader.num_layers = self.adaptor.num_dense_layers + self.adaptor.num_moe_layers def init_eplb(self, expert_map_path, process): - self.rank_id = dist.get_rank() + self.rank_id = self.comm_group.rank_in_group self.num_expert_load_gather = 10 self.periodic_load_gather = True self.expert_heat_collection_interval: torch.int64 = self.eplb_config.expert_heat_collection_interval @@ -131,10 +130,11 @@ def forward_end(self): self.update_iteration() def compute_and_set_moe_load(self): - local_load = self.adaptor.get_rank_expert_workload() - moe_load = ( - self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu() - ) + self.world_size = self.comm_group.world_size + local_load = self.adaptor.get_rank_expert_workload().cpu() + gather_buffer = [torch.empty_like(local_load) for _ in range(self.world_size)] + self.comm_group.cpu_group.allgather(gather_buffer, local_load).wait() + moe_load = torch.stack(gather_buffer).permute(1, 
0, 2) if self.multi_stage: moe_load = moe_load.permute(2, 0, 1, 3) @@ -145,27 +145,32 @@ def compute_and_set_moe_load(self): return moe_load def warm_up_eplb(self): - self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map() + if self.shared_dict["expert_maps"] is None: + self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map() self.compute_and_set_moe_load() src_tensor = torch.empty((1,), device=self.device) comm_op_list = [] + for src_rank in range(self.world_size): + for dst_rank in range(self.world_size): + if src_rank != dst_rank: + comm_op_list.append({"src_rank": src_rank, "dst_rank": dst_rank, "tensor": src_tensor}) - for dst_rank in range(self.world_size): - if dst_rank == self.rank_id: - continue - comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank, group=self.comm_group.device_group)) + comm_op_list = sorted(comm_op_list, key=lambda x: (x["src_rank"], x["dst_rank"])) - for src_rank in range(self.world_size): + workers = [] + for i, op in enumerate(comm_op_list): + src_rank = op["src_rank"] + dst_rank = op["dst_rank"] + tensor = op["tensor"] if src_rank == self.rank_id: - continue - comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank, group=self.comm_group.device_group)) - if comm_op_list: - reqs = dist.batch_isend_irecv(comm_op_list) + workers.append(self.comm_group.device_group.send([tensor], dst_rank, tag=i)) + elif dst_rank == self.rank_id: + workers.append(self.comm_group.device_group.recv([tensor], src_rank, tag=i)) - for req in reqs: - req.wait() + for worker in workers: + worker.wait() def shutdown(self): """ diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 70b9e2bcd28..b6ea2ebe5b6 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import os from collections.abc import Callable from dataclasses import dataclass, field from functools import wraps @@ -384,7 +385,8 @@ def __init__(self, *args, **kwargs): self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_npugraph_ex_static_kernel = ascend_config.ascend_compilation_config.enable_static_kernel - setup_moe_comm_method(self.moe_config) + if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") != "1": + setup_moe_comm_method(self.moe_config) self.quant_type = self._get_quant_type() self.runner = self._init_runner() diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index 6112ec9e819..0d3eaa5baf1 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -86,7 +86,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) device_group = get_mc2_group().device_group # TODO: Try local_rank = ep_group.rank_in_group - local_rank = torch.distributed.get_rank(group=device_group) + local_rank = get_mc2_group().rank_in_group backend = device_group._get_backend(torch.device("npu")) self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) self.ep_rank_id = get_mc2_group().rank_in_group @@ -403,7 +403,7 @@ def __init__(self, **kwargs): ) # TODO: Try local_rank = ep_group.rank_in_group - local_rank = torch.distributed.get_rank(group=self.ep_group) + local_rank = get_ep_group().rank_in_group backend = self.ep_group._get_backend(torch.device("npu")) self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index f09c524e60f..a9ae5348cec 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -19,11 +19,13 @@ import math import os +from datetime import timedelta from typing import TYPE_CHECKING, Any from uuid import uuid4 import torch import vllm.envs as envs_vllm +from torch.distributed.distributed_c10d import Backend, PrefixStore, ProcessGroup from vllm.logger import logger from vllm.platforms import Platform, PlatformEnum @@ -865,3 +867,76 @@ def _fix_incompatible_config(vllm_config: VllmConfig) -> None: @classmethod def use_custom_op_collectives(cls) -> bool: return True + + @classmethod + def stateless_init_device_torch_dist_pg( + cls, + backend: str, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, + ) -> ProcessGroup: + """ + Initialize a stateless HCCL process group for CUDA devices. + This method creates a ProcessGroup with the specified backend configuration, + typically used for GPU communication. It sets up the necessary backend + options and registers the backend with the process group. + Args: + backend: The distributed backend to use (e.g., 'hccl') + prefix_store: The prefix store for distributed coordination + group_rank: The rank of the current process within the group + group_size: The total number of processes in the group + timeout: Maximum time to wait for the operation to complete + **kwargs: Additional backend-specific options + warning: + Uses internal PyTorch API (torch._C._distributed_c10d.ProcessGroupHCCL) + which may change in future PyTorch versions. Compatibility should be + verified with each PyTorch upgrade. 
+ Compatibility Risk: + - High risk of breakage in PyTorch 2.4+ + - No semantic versioning guarantees + - Requires testing with new PyTorch releases + Returns: + A ProcessGroup object configured with the specified backend + """ + + # INTERNAL API USAGE - COMPATIBILITY RISK + # This internal import is necessary for stateless process group functionality + # but carries compatibility risks. Monitor PyTorch release notes for changes. + # TODO: Migrate to public API when available in future PyTorch versions + from torch_npu._C._distributed_c10d import ProcessGroupHCCL + import uuid + + pg = ProcessGroup(prefix_store, group_rank, group_size) + + backend_options = ProcessGroupHCCL.Options() + backend_options._timeout = timeout + + # Create Backend object + backend = Backend("hccl") + + # Set default backend for ProcessGroup + pg._set_default_backend(Backend.backend_type_map[backend]) + + device = torch.device("npu") + if hasattr(backend_options, "_device"): + backend_options._device = device + + backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, backend_options) + + backend_class._set_sequence_number_for_group() + backend_type = ProcessGroup.BackendType.CUSTOM + pg._register_backend(device, backend_type, backend_class) + hccl_comm_name = None + if group_rank == 0: + hccl_comm_name = uuid.uuid4().hex + pg.get_group_store().set("hccl_comm_name", hccl_comm_name) + else: + hccl_comm_name = pg.get_group_store().get("hccl_comm_name").decode("utf-8") + if hccl_comm_name is not None: + group_desc = "undefined" + backend_class._set_hccl_comm_name(hccl_comm_name) + pg._set_group_desc(group_desc) + + return pg diff --git a/vllm_ascend/quantization/methods/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py index 09e2f964455..90d9fcbddd6 100644 --- a/vllm_ascend/quantization/methods/w8a8_dynamic.py +++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py @@ -15,6 +15,7 @@ # limitations under the License. 
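# Illustrative sketch (dict-backed stand-in for the group store): in the
# stateless process-group setup above, rank 0 mints a random communicator
# name with uuid4 and publishes it through the store, while the other ranks
# read the same key and attach to the same HCCL communicator. Only the key
# "hccl_comm_name" comes from the patch; DictStore and the sizes are assumed.
import uuid


class DictStore:
    def __init__(self) -> None:
        self._kv: dict[str, bytes] = {}

    def set(self, key: str, value: str) -> None:
        self._kv[key] = value.encode("utf-8")

    def get(self, key: str) -> bytes:
        # The real store blocks until the key is published; here we assume it is.
        return self._kv[key]


def exchange_comm_name(store: DictStore, group_rank: int) -> str:
    if group_rank == 0:
        name = uuid.uuid4().hex
        store.set("hccl_comm_name", name)
        return name
    return store.get("hccl_comm_name").decode("utf-8")


store = DictStore()
name_rank0 = exchange_comm_name(store, group_rank=0)
name_rank1 = exchange_comm_name(store, group_rank=1)
assert name_rank0 == name_rank1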
# +import os from collections.abc import Callable from typing import Any @@ -127,14 +128,15 @@ def __init__(self): self.in_dtype = vllm_config.model_config.dtype self.supports_eplb = True - try: - device_group = get_mc2_group().device_group - # TODO: Try local_rank = ep_group.rank_in_group - local_rank = torch.distributed.get_rank(group=device_group) - backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) - except AttributeError: - self.moe_all_to_all_group_name = "" + if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") != "1": + try: + device_group = get_mc2_group().device_group + # TODO: Try local_rank = ep_group.rank_in_group + local_rank = get_mc2_group().rank_in_group + backend = device_group._get_backend(torch.device("npu")) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) + except AttributeError: + self.moe_all_to_all_group_name = "" def get_weight( self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 3c812aa4432..f40c275092a 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -192,7 +192,7 @@ def commit_slot_mapping(self, num_tokens: int) -> None: self.slot_mapping.copy_to_gpu(num_tokens) def clear(self) -> None: - self.block_table.fill_(0) + self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) def _convert_physical_to_logical_blocks(self, physical_blocks: np.ndarray) -> np.ndarray: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 182f16a3728..c2cc05a9c00 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -18,6 +18,7 @@ # import math +import os import sys from collections import defaultdict from contextlib import contextmanager, nullcontext @@ -359,7 +360,16 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.policy_type = eplb_config.eplb_policy_type self.eplb_loader = D2DExpertWeightLoader() self.manager = Manager() - self.shared_dict = self.manager.dict({"expert_map": None, "moe_load": None, "expert_maps": None}) + self.shared_dict = self.manager.dict( + { + "expert_map": None, + "moe_load": None, + "expert_maps": None, + "scale": False, + "old_ep_size": None, + "new_ep_size": None, + } + ) self.eplb_process = EplbProcess(shared_dict=self.shared_dict, policy_type=self.policy_type, enable_d2d=True) self.process = self.eplb_process._launch_process() self.eplb_updator = EplbUpdator(eplb_config, self.eplb_loader, self.eplb_process, self.process) @@ -2536,7 +2546,6 @@ def _dummy_sampler_run( return output def profile_run(self) -> None: - self.eplb_warmup() mc2_tokens_capacity = get_mc2_tokens_capacity() if self.max_num_tokens > mc2_tokens_capacity and select_moe_comm_method( mc2_tokens_capacity, self.vllm_config @@ -2556,9 +2565,10 @@ def eplb_warmup(self): self.eplb_adaptor = VllmEplbAdaptor(model=self.model) self.eplb_loader.set_adator(self.eplb_adaptor) self.eplb_updator.set_adaptor(self.eplb_adaptor) - self.eplb_updator.warm_up_eplb() + if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") != "1": + self.eplb_updator.warm_up_eplb() - def load_model(self) -> None: + def load_model(self, load_dummy_weights: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) if self.ascend_config.mix_placement: @@ -2572,7 +2582,9 @@ def mock_true(): with 
DeviceMemoryProfiler() as m: # noqa: SIM117 if self.eplb_enable: self.vllm_config.parallel_config.enable_eplb = True - self.model: nn.Module = get_model(vllm_config=self.vllm_config) + if load_dummy_weights: + self.load_config.load_format = "dummy" + self.model = get_model(vllm_config=self.vllm_config, load_config=self.load_config) if self.dynamic_eplb: model_register(self.model) if self.drafter: diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index c5405a7c3a0..53abd0a1065 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -116,6 +116,10 @@ def __init__( is_driver_worker=is_driver_worker, ) + from vllm_ascend.distributed.elastic_ep.elastic_execute import AscendElasticEPScalingExecutor + + self.elastic_ep_executor = AscendElasticEPScalingExecutor(self) + if self.cache_config.cache_dtype == "auto": self.cache_dtype = self.model_config.dtype else: @@ -426,7 +430,7 @@ def execute_model( def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput | AsyncModelRunnerOutput: return self.model_runner.sample_tokens(grammar_output) - def load_model(self) -> None: + def load_model(self, *, load_dummy_weights: bool = False) -> None: if self.vllm_config.model_config.enable_sleep_mode: allocator = CaMemAllocator.get_instance() assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process." @@ -437,7 +441,9 @@ def load_model(self) -> None: context = nullcontext() # type: ignore with context, set_current_vllm_config(self.vllm_config): - self.model_runner.load_model() + self.model_runner.load_model(load_dummy_weights) + + self.model_runner.eplb_warmup() def compile_or_warm_up_model(self) -> float: # Note: need to adapt for graph mode. @@ -661,6 +667,9 @@ def check_health(self) -> None: logger.info(f"query NPU card {self.local_rank} fail: {e}") return + def elastic_ep_execute(self, execute_method: str, *args, **kwargs): + return self.elastic_ep_executor.execute(execute_method, *args, **kwargs) + def parse_text_output(output) -> None: lines = output.strip().split("\n") From aac1c936c69b4866b502ac942eb30fab00cab3bc Mon Sep 17 00:00:00 2001 From: nifeng <1542305589@qq.com> Date: Fri, 3 Apr 2026 17:32:25 +0800 Subject: [PATCH 2/4] fix some bugs by copilot's review Signed-off-by: nifeng <1542305589@qq.com> --- vllm_ascend/compilation/acl_graph.py | 6 +- .../device_communicators/pyhccl.py | 7 - .../distributed/elastic_ep/elastic_execute.py | 123 +++++++++++++++++- vllm_ascend/distributed/parallel_state.py | 1 - vllm_ascend/eplb/adaptor/vllm_adaptor.py | 11 +- .../eplb/core/eplb_device_transfer_loader.py | 20 +-- vllm_ascend/eplb/core/eplb_worker.py | 4 +- vllm_ascend/eplb/eplb_updator.py | 3 +- vllm_ascend/platform.py | 34 ++--- 9 files changed, 165 insertions(+), 44 deletions(-) diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 6b5ad1e51a4..403f9f1b956 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -61,13 +61,13 @@ class ACLGraphWrapper: guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". 
""" - all_instances: ClassVar[weakref.WeakSet["ACLGraphWrapper"]] = weakref.WeakSet() + _all_instances: ClassVar[weakref.WeakSet["ACLGraphWrapper"]] = weakref.WeakSet() graph_pool: ClassVar[tuple[int, int]] = current_platform.get_global_graph_pool() @classmethod def clear_all_graphs(cls) -> None: """Clear all graphs from all ACLGraphWrapper instances.""" - for instance in list(cls.all_instances): + for instance in list(cls._all_instances): instance.clear_graphs() cls.graph_pool = (cls.graph_pool[0], cls.graph_pool[1] + 1) @@ -98,6 +98,8 @@ def __init__( # aclgraphs for. self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {} + ACLGraphWrapper._all_instances.add(self) + def __getattr__(self, key: str): # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): diff --git a/vllm_ascend/distributed/device_communicators/pyhccl.py b/vllm_ascend/distributed/device_communicators/pyhccl.py index 9c0c797e9b2..607fe09e4cc 100644 --- a/vllm_ascend/distributed/device_communicators/pyhccl.py +++ b/vllm_ascend/distributed/device_communicators/pyhccl.py @@ -120,13 +120,6 @@ def __init__( with torch.npu.device(device): self.comm: hcclComm_t = self.hccl.hcclCommInitRank(self.world_size, self.unique_id, self.rank) - stream = current_stream() - # A small all_reduce for warmup. - data = torch.zeros(1, device=device) - self.all_reduce(data) - stream.synchronize() - del data - def destroy(self): if self.available and not self.disabled: with torch.accelerator.device_index(self.device.index): diff --git a/vllm_ascend/distributed/elastic_ep/elastic_execute.py b/vllm_ascend/distributed/elastic_ep/elastic_execute.py index 6aef9e1d1dc..94653efdab4 100644 --- a/vllm_ascend/distributed/elastic_ep/elastic_execute.py +++ b/vllm_ascend/distributed/elastic_ep/elastic_execute.py @@ -1,9 +1,12 @@ import copy import gc +from collections.abc import Iterable, Sequence import numpy as np import torch +import torch.nn as nn import torch_npu +from torch.distributed import P2POp from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import reset_compile_wrapper from vllm.config import ( @@ -45,6 +48,53 @@ from vllm_ascend.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod +def batch_transfer_weights( + model: nn.Module, + is_sender: bool, + peer_rank: int, + dp_group: StatelessGroupCoordinator, + expert_weights: Sequence[Iterable[torch.Tensor]], +) -> None: + device_comm = dp_group.device_communicator + if device_comm is None: + raise ValueError("No device communicator found") + + expert_weights_set = set() + for weight_group in expert_weights: + for weight in weight_group: + expert_weights_set.add(weight.data_ptr()) + + state_dict = model.state_dict() + all_params = [] + + for name, param in state_dict.items(): + if name.endswith("expert_map"): + continue + if param.data_ptr() not in expert_weights_set: + all_params.append(param.data) + + quant_weight_names = ["aclnn_input_scale", "aclnn_input_scale_reciprocal", "aclnn_input_offset"] + for module in model.modules(): + for name in quant_weight_names: + if (param := getattr(module, name, None)) is not None: + all_params.append(param) + + assert len(all_params) > 0 + p2p_ops = [] + for param in all_params: + op = object.__new__(P2POp) + if is_sender: + op.op = torch.distributed.isend + op.tensor = param + else: + op.op = torch.distributed.irecv + op.tensor = param + op.group_peer = peer_rank + p2p_ops.append(op) + + device_comm.batch_isend_irecv(p2p_ops) + + def broadcast_expert_mapping( 
expert_maps: torch.Tensor | None, group: StatelessGroupCoordinator, @@ -114,6 +164,45 @@ def create_standby_groups(self, reconfig_request: ReconfigureDistributedRequest) coord_store_port=reconfig_request.coord_store_port, ) + def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None: + standby_dp_group = get_standby_dp_group() + assert standby_dp_group is not None + # Broadcast old_dp_size to all workers in standby group + if standby_dp_group.rank_in_group < old_dp_size: + old_dp_size_tensor = torch.tensor([old_dp_size], dtype=torch.int64, device="cpu") + else: + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) + + num_new_workers = new_dp_size - old_dp_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + # Sender-receiver pairing: the first new_workers % old_dp_size + # senders get (k+1) contiguous receivers, the rest get k + # receivers. + num_dst_per_sender = num_new_workers // old_dp_size + remainder = num_new_workers % old_dp_size + + if dp_rank < remainder: + recv_begin = dp_rank * (num_dst_per_sender + 1) + recv_end = recv_begin + num_dst_per_sender + 1 + else: + recv_begin = remainder * (num_dst_per_sender + 1) + (dp_rank - remainder) * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + + ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end)) + + model = self.worker.model_runner.get_model() + for new_worker_rank in sorted(ranks_to_send): + batch_transfer_weights( + model=model, + is_sender=True, + peer_rank=new_worker_rank, + dp_group=standby_dp_group, + expert_weights=model.expert_weights, + ) + torch.accelerator.synchronize() + def broadcast_expert_mapping(self): standby_dp_group = get_standby_dp_group() assert standby_dp_group is not None @@ -294,9 +383,9 @@ def perform_eplb_reshuffle(self) -> None: self._perform_eplb_reshuffle() def perform_scale_down_eplb_reshuffle(self, new_dp_size: int) -> None: + old_ep_size = get_ep_group().world_size parallel_config = self.worker.vllm_config.parallel_config tp_size = parallel_config.tensor_parallel_size - old_ep_size = parallel_config.data_parallel_size * tp_size new_ep_size = new_dp_size * tp_size self.worker.model_runner.shared_dict["scale"] = True @@ -305,6 +394,38 @@ def perform_scale_down_eplb_reshuffle(self, new_dp_size: int) -> None: self._perform_eplb_reshuffle() + def receive_weights(self) -> None: + dp_group = get_dp_group() + assert isinstance(dp_group, StatelessGroupCoordinator) + new_dp_size = dp_group.world_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + # Receive old_dp_size broadcasted during transfer_weights + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) + old_dp_size = int(old_dp_size_tensor[0].item()) + + # Calculate which existing worker will send to this new worker + num_new_workers = new_dp_size - old_dp_size + new_worker_idx = dp_rank - old_dp_size + num_dst_per_sender = num_new_workers // old_dp_size + remainder = num_new_workers % old_dp_size + + if new_worker_idx < remainder * (num_dst_per_sender + 1): + sender_rank = new_worker_idx // (num_dst_per_sender + 1) + else: + sender_rank = remainder + (new_worker_idx - remainder * (num_dst_per_sender + 1)) // num_dst_per_sender + + model = self.worker.model_runner.get_model() + batch_transfer_weights( + model=model, + is_sender=False, + 
peer_rank=sender_rank, + dp_group=dp_group, + expert_weights=model.expert_weights, + ) + torch.accelerator.synchronize() + def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]: dp_group = get_dp_group() assert isinstance(dp_group, StatelessGroupCoordinator) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index f502ebb8ecd..951a6161f67 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -42,7 +42,6 @@ def init_ascend_model_parallel( assert torch.distributed.is_initialized() enable_elastic_ep = parallel_config.enable_elastic_ep world_size = torch.distributed.get_world_size() - backend = torch.distributed.get_backend(get_world_group().device_group) global_tp_size = parallel_config.tensor_parallel_size global_dp_size = parallel_config.data_parallel_size global_pp_size = parallel_config.pipeline_parallel_size diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 4749e43d54e..3a729ad8f81 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -19,10 +19,10 @@ from typing import Any import torch -import torch.distributed as dist from vllm.logger import logger import vllm_ascend.envs as envs_ascend +from vllm_ascend.distributed.parallel_state import get_dynamic_eplb_group from vllm_ascend.quantization.methods.base import QuantType @@ -30,8 +30,8 @@ class VllmEplbAdaptor: def __init__(self, model, **args): super().__init__(**args) self.model = model - self.rank_id = dist.get_rank() - self.world_size = dist.get_world_size() + self.rank_id = get_dynamic_eplb_group().rank_in_group + self.world_size = get_dynamic_eplb_group().world_size self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0) self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers @@ -126,7 +126,10 @@ def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str): json.dump(record, f, indent=4) def do_update_expert_map(self, layer_id, updated_expert_map): - self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) + if layer_id in self.expert_map_per_layer_cpu: + self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map) + else: + self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.cpu() def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id): for expert_tensor, buffer_tensor in zip( diff --git a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py index 69a127457ae..6a02e9c74bd 100644 --- a/vllm_ascend/eplb/core/eplb_device_transfer_loader.py +++ b/vllm_ascend/eplb/core/eplb_device_transfer_loader.py @@ -16,7 +16,6 @@ # from enum import Enum -import torch.distributed as dist from vllm.logger import logger from vllm_ascend.distributed.parallel_state import get_dynamic_eplb_group @@ -60,19 +59,23 @@ def generate_expert_d2d_transfer_task(self, expert_send_info, expert_recv_info, for index, src_tensor in enumerate(self.eplb_adaptor.expert_param_per_layer[layer_id][local_expert_id]): if hasattr(self.eplb_adaptor, "temp_tensor_list"): if local_expert_id not in has_temp: - self.eplb_adaptor.temp_tensor_list[layer_id][index].copy_(src_tensor) - op_info["tensors"].append(self.eplb_adaptor.temp_tensor_list[layer_id][index]) + self.eplb_adaptor.temp_tensor_list[local_expert_id][index].copy_(src_tensor) + 
op_info["tensors"].append(self.eplb_adaptor.temp_tensor_list[local_expert_id][index]) + else: + op_info["tensors"].append(src_tensor) has_temp.add(local_expert_id) self.comm_op_list.append(op_info) for buffer_tensor_id, recv_info in enumerate(expert_recv_info): recv_rank, global_expert_id_to_recv = recv_info + op_info = {"peer_rank": recv_rank, "tensors": [], "expert_id": global_expert_id_to_recv, "op": "recv"} for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[buffer_tensor_id]: - self.comm_op_list.append( - dist.P2POp(dist.irecv, buffer_tensor, recv_rank, group=self.comm_group.device_group) - ) + op_info["tensors"].append(buffer_tensor) local_expert_to_replace = self.updated_expert_map[global_expert_id_to_recv].item() self.recv_expert_list.append((local_expert_to_replace, buffer_tensor_id)) + self.comm_op_list.append(op_info) + + self.comm_op_list = sorted(self.comm_op_list, key=lambda x: x["expert_id"]) self.state = ExpertWeightUpdateState.READY @@ -92,10 +95,11 @@ def asyn_expert_weight_transfer(self, reqs): expert_id = op_info["expert_id"] op = op_info["op"] for i, tensor in enumerate(tensors): + tag = expert_id * len(tensors) + i if op == "send": - worker = self.comm_group.device_group.send([tensor], peer_rank, tag=(expert_id + 1) * (i + 1)) + worker = self.comm_group.device_group.send([tensor], peer_rank, tag=tag) else: - worker = self.comm_group.device_group.recv([tensor], peer_rank, tag=(expert_id + 1) * (i + 1)) + worker = self.comm_group.device_group.recv([tensor], peer_rank, tag=tag) reqs.append(worker) self.state = ExpertWeightUpdateState.TRANSFERRING diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index a10fcb9d415..8fc8315e7cd 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -94,7 +94,7 @@ def do_update(self): new_expert_maps_clone = new_expert_maps.clone() if scale: - shape = list(new_expert_maps_clone.shape) + shape = list(new_expert_maps.shape) shape[1] = abs(old_ep_size - new_ep_size) if old_ep_size > new_ep_size: # when scale down, ensure that the shutdown ranks do not own any experts @@ -108,7 +108,7 @@ def do_update(self): self.old_expert_maps = torch.cat([self.old_expert_maps, new_rank_expert_maps], dim=1) update_info = self.compose_expert_update_info_greedy(new_expert_maps, self.old_expert_maps) - self.old_expert_maps = new_expert_maps + self.old_expert_maps = new_expert_maps_clone logger.debug("EPLB Process compute complete") packed_update_info = self.pack_update_info(update_info) diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index 5bdc8cb8ff8..ab70012dba8 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -145,8 +145,7 @@ def compute_and_set_moe_load(self): return moe_load def warm_up_eplb(self): - if self.shared_dict["expert_maps"] is None: - self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map() + self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map() self.compute_and_set_moe_load() src_tensor = torch.empty((1,), device=self.device) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index a9ae5348cec..452970da19f 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -878,27 +878,27 @@ def stateless_init_device_torch_dist_pg( timeout: timedelta, ) -> ProcessGroup: """ - Initialize a stateless HCCL process group for CUDA devices. + Initialize a stateless HCCL process group for Ascend NPU. 
This method creates a ProcessGroup with the specified backend configuration, - typically used for GPU communication. It sets up the necessary backend - options and registers the backend with the process group. + typically used for NPU communication via HCCL. It sets up the necessary + backend options and registers the backend with the process group. Args: - backend: The distributed backend to use (e.g., 'hccl') - prefix_store: The prefix store for distributed coordination - group_rank: The rank of the current process within the group - group_size: The total number of processes in the group - timeout: Maximum time to wait for the operation to complete - **kwargs: Additional backend-specific options - warning: - Uses internal PyTorch API (torch._C._distributed_c10d.ProcessGroupHCCL) - which may change in future PyTorch versions. Compatibility should be - verified with each PyTorch upgrade. + backend: The distributed backend to use. Currently only 'hccl' is + supported for Ascend NPUs. + prefix_store: The prefix store for distributed coordination. + group_rank: The rank of the current process within the group. + group_size: The total number of processes in the group. + timeout: Maximum time to wait for the operation to complete. + Warning: + Uses internal PyTorch NPU API (torch_npu._C._distributed_c10d.ProcessGroupHCCL) + which may change in future PyTorch / torch_npu versions. Compatibility + should be verified with each upgrade. Compatibility Risk: - - High risk of breakage in PyTorch 2.4+ - - No semantic versioning guarantees - - Requires testing with new PyTorch releases + - High risk of breakage in future PyTorch / torch_npu releases. + - No semantic versioning guarantees for internal APIs. + - Requires testing with new PyTorch / torch_npu releases. Returns: - A ProcessGroup object configured with the specified backend + A ProcessGroup object configured with the specified backend. """ # INTERNAL API USAGE - COMPATIBILITY RISK From a0f57c3237d436a57f921042991955050c0a799e Mon Sep 17 00:00:00 2001 From: nifeng <1542305589@qq.com> Date: Tue, 7 Apr 2026 20:36:32 +0800 Subject: [PATCH 3/4] Use custom weight transfer patch in AscendElasticEPScalingExecutor.transfer_weights and receive_weights Signed-off-by: nifeng <1542305589@qq.com> --- .../distributed/elastic_ep/elastic_execute.py | 123 ++++++++---------- 1 file changed, 55 insertions(+), 68 deletions(-) diff --git a/vllm_ascend/distributed/elastic_ep/elastic_execute.py b/vllm_ascend/distributed/elastic_ep/elastic_execute.py index 94653efdab4..2cb1529cc76 100644 --- a/vllm_ascend/distributed/elastic_ep/elastic_execute.py +++ b/vllm_ascend/distributed/elastic_ep/elastic_execute.py @@ -1,11 +1,49 @@ +# NOTE: +# This file is adapted from vLLM's elastic_execute.py +# +# Key differences: +# 1. Device-specific adaptations: Replaces CUDA-specific operations with NPU (Ascend) equivalents +# - Uses `torch_npu` instead of CUDA APIs +# - Replaces `torch.accelerator.synchronize()` with `torch.npu.synchronize()` +# - Replaces `torch.accelerator.empty_cache()` with `torch.npu.empty_cache()` +# - Uses `ACLGraphWrapper` instead of `CUDAGraphWrapper` for graph management +# +# 2. Custom weight transfer implementation: Implements `ascend_batch_transfer_weights()` +# - Adds support for quantized weight names (aclnn_input_scale, aclnn_input_scale_reciprocal, aclnn_input_offset) +# - Uses threading lock (`_PATCH_LOCK`) for thread-safe weight transfer patching +# +# 3. 
Enhanced broadcast_expert_mapping: Simplified signature and implementation +# - Removed `physical_to_logical`, `num_local_physical_experts`, `num_logical_experts` parameters +# - Uses `expert_maps` tensor directly for broadcasting +# +# 4. Extended AscendElasticEPScalingExecutor class: +# - Adds `_use_ascend_transfer_impl()` context manager for patching weight transfer +# - Implements `_release_acl_graphs()` to clear ACL graphs instead of CUDA graphs +# - Adds `_replace_ascend_active_groups()` calls for Ascend-specific group management +# - Integrates with `create_ascend_standby_groups()` and `pop_ascend_standby_groups()` +# - Adds support for Ascend-specific MoE modules (AscendFusedMoE, AscendSharedFusedMoE) +# - Handles Ascend-specific quantization method (AscendW8A8DynamicFusedMoEMethod) +# - Integrates with `get_mc2_group()` and `get_dynamic_eplb_group()` for Ascend communication +# - Adds `setup_moe_comm_method()` calls for MoE communication setup +# +# 5. EPLB (Expert Parallel Load Balancing) adaptations: +# - Uses `eplb_loader`, `eplb_adaptor`, `eplb_updator` from model_runner +# - Implements `_perform_eplb_reshuffle()` with expert resharding logic +# - Handles dynamic EPLB configuration via `get_ascend_config().eplb_config` +# +# ============================================================ + import copy import gc +import threading from collections.abc import Iterable, Sequence +from contextlib import contextmanager import numpy as np import torch import torch.nn as nn import torch_npu +import vllm.distributed.elastic_ep.elastic_execute as elastic_execute_mod from torch.distributed import P2POp from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import reset_compile_wrapper @@ -47,8 +85,10 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method from vllm_ascend.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod +_PATCH_LOCK = threading.Lock() + -def batch_transfer_weights( +def ascend_batch_transfer_weights( model: nn.Module, is_sender: bool, peer_rank: int, @@ -122,8 +162,14 @@ def broadcast_expert_mapping( class AscendElasticEPScalingExecutor(ElasticEPScalingExecutor): - def __init__(self, worker): - super().__init__(worker) + @contextmanager + def _use_ascend_transfer_impl(self): + old_impl = elastic_execute_mod.batch_transfer_weights + elastic_execute_mod.batch_transfer_weights = ascend_batch_transfer_weights + try: + yield + finally: + elastic_execute_mod.batch_transfer_weights = old_impl def load_model(self) -> None: ( @@ -165,43 +211,10 @@ def create_standby_groups(self, reconfig_request: ReconfigureDistributedRequest) ) def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None: - standby_dp_group = get_standby_dp_group() - assert standby_dp_group is not None - # Broadcast old_dp_size to all workers in standby group - if standby_dp_group.rank_in_group < old_dp_size: - old_dp_size_tensor = torch.tensor([old_dp_size], dtype=torch.int64, device="cpu") - else: - old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") - old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) - - num_new_workers = new_dp_size - old_dp_size - dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank - - # Sender-receiver pairing: the first new_workers % old_dp_size - # senders get (k+1) contiguous receivers, the rest get k - # receivers. 
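# Illustrative sketch of the sender/receiver pairing that transfer_weights and
# receive_weights rely on (the local copy is removed above in favour of the
# base executor's implementation): the first (new_workers % old_dp_size)
# senders each feed k+1 contiguous new workers, the remaining senders feed k,
# where k = new_workers // old_dp_size. The helpers below recompute both views
# of the mapping and check that they agree; all names and sizes are local to
# this sketch.
def receivers_for_sender(dp_rank: int, old_dp_size: int, new_dp_size: int) -> list[int]:
    num_new = new_dp_size - old_dp_size
    k, remainder = divmod(num_new, old_dp_size)
    if dp_rank < remainder:
        begin = dp_rank * (k + 1)
        end = begin + k + 1
    else:
        begin = remainder * (k + 1) + (dp_rank - remainder) * k
        end = begin + k
    return [old_dp_size + idx for idx in range(begin, end)]


def sender_for_receiver(dp_rank: int, old_dp_size: int, new_dp_size: int) -> int:
    num_new = new_dp_size - old_dp_size
    k, remainder = divmod(num_new, old_dp_size)
    idx = dp_rank - old_dp_size
    if idx < remainder * (k + 1):
        return idx // (k + 1)
    return remainder + (idx - remainder * (k + 1)) // k


old_dp_size, new_dp_size = 4, 11  # scale up from 4 to 11 data-parallel ranks
for sender in range(old_dp_size):
    for receiver in receivers_for_sender(sender, old_dp_size, new_dp_size):
        assert sender_for_receiver(receiver, old_dp_size, new_dp_size) == sender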
- num_dst_per_sender = num_new_workers // old_dp_size - remainder = num_new_workers % old_dp_size - - if dp_rank < remainder: - recv_begin = dp_rank * (num_dst_per_sender + 1) - recv_end = recv_begin + num_dst_per_sender + 1 - else: - recv_begin = remainder * (num_dst_per_sender + 1) + (dp_rank - remainder) * num_dst_per_sender - recv_end = recv_begin + num_dst_per_sender - - ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end)) - model = self.worker.model_runner.get_model() - for new_worker_rank in sorted(ranks_to_send): - batch_transfer_weights( - model=model, - is_sender=True, - peer_rank=new_worker_rank, - dp_group=standby_dp_group, - expert_weights=model.expert_weights, - ) - torch.accelerator.synchronize() + model.expert_weights = [item[1] for item in self.worker.model_runner.eplb_adaptor.param_dict.items()] + with _PATCH_LOCK, self._use_ascend_transfer_impl(): + super().transfer_weights(old_dp_size=old_dp_size, new_dp_size=new_dp_size) def broadcast_expert_mapping(self): standby_dp_group = get_standby_dp_group() @@ -395,36 +408,10 @@ def perform_scale_down_eplb_reshuffle(self, new_dp_size: int) -> None: self._perform_eplb_reshuffle() def receive_weights(self) -> None: - dp_group = get_dp_group() - assert isinstance(dp_group, StatelessGroupCoordinator) - new_dp_size = dp_group.world_size - dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank - - # Receive old_dp_size broadcasted during transfer_weights - old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") - old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) - old_dp_size = int(old_dp_size_tensor[0].item()) - - # Calculate which existing worker will send to this new worker - num_new_workers = new_dp_size - old_dp_size - new_worker_idx = dp_rank - old_dp_size - num_dst_per_sender = num_new_workers // old_dp_size - remainder = num_new_workers % old_dp_size - - if new_worker_idx < remainder * (num_dst_per_sender + 1): - sender_rank = new_worker_idx // (num_dst_per_sender + 1) - else: - sender_rank = remainder + (new_worker_idx - remainder * (num_dst_per_sender + 1)) // num_dst_per_sender - model = self.worker.model_runner.get_model() - batch_transfer_weights( - model=model, - is_sender=False, - peer_rank=sender_rank, - dp_group=dp_group, - expert_weights=model.expert_weights, - ) - torch.accelerator.synchronize() + model.expert_weights = [item[1] for item in self.worker.model_runner.eplb_adaptor.param_dict.items()] + with _PATCH_LOCK, self._use_ascend_transfer_impl(): + super().receive_weights() def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]: dp_group = get_dp_group() From dd7348c07d71f5cf49abcfdeb07983fb9794796f Mon Sep 17 00:00:00 2001 From: nifeng <1542305589@qq.com> Date: Wed, 8 Apr 2026 15:39:41 +0800 Subject: [PATCH 4/4] change MoE modules detection Signed-off-by: nifeng <1542305589@qq.com> --- vllm_ascend/distributed/elastic_ep/elastic_execute.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/distributed/elastic_ep/elastic_execute.py b/vllm_ascend/distributed/elastic_ep/elastic_execute.py index 2cb1529cc76..4ec2d847bb0 100644 --- a/vllm_ascend/distributed/elastic_ep/elastic_execute.py +++ b/vllm_ascend/distributed/elastic_ep/elastic_execute.py @@ -66,7 +66,7 @@ from vllm.distributed.parallel_state import _replace_active_groups from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.logger import logger -from vllm.model_executor.layers.fused_moe.layer import 
FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig, FusedMoE from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.workspace import lock_workspace, unlock_workspace @@ -277,7 +277,7 @@ def switch_and_prepare(self) -> None: moe_modules = [ module for module in self.worker.model_runner.model.modules() - if (module.__class__.__name__ == "AscendFusedMoE" or module.__class__.__name__ == "AscendSharedFusedMoE") + if isinstance(module, FusedMoE) ] num_local_experts = moe_modules[0].moe_config.num_local_experts assert all(module.moe_config.num_local_experts == num_local_experts for module in moe_modules), ( @@ -430,7 +430,7 @@ def prepare_new_worker(self) -> None: moe_modules = [ module for module in self.worker.model_runner.model.modules() - if (module.__class__.__name__ == "AscendFusedMoE" or module.__class__.__name__ == "AscendSharedFusedMoE") + if isinstance(module, FusedMoE) ] for module in moe_modules: with set_current_vllm_config(self.worker.vllm_config):