diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index 0513307ab74..f5f01a6071d 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -32,8 +32,7 @@
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.comm_utils import (
     async_all_to_all, gather_from_sequence_parallel_region)
-from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
-                               is_hierarchical_communication_enabled)
+from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 
 
 @dataclass
@@ -117,10 +116,6 @@ def __init__(self, **kwargs):
         self.need_extra_args = (
             get_ascend_device_type() == AscendDeviceType.A3)
 
-        # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
-        # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
-        # improve communication performance.
-        self.need_expert_scale = is_hierarchical_communication_enabled()
         self.with_quant = False
 
         # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
@@ -158,6 +153,7 @@ def get_dispatch_mc2_kwargs(
         else:
             quant_mode = 0
         moe_expert_num = len(expert_map)
+
         kwargs_mc2 = {
             "x": hidden_states,
             "expert_ids": topk_ids,
@@ -166,8 +162,12 @@ def get_dispatch_mc2_kwargs(
             "moe_expert_num": moe_expert_num,
             "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
+            "expert_scales": topk_weights.to(torch.float32),
         }
 
+        if get_ascend_device_type() == AscendDeviceType.A2:
+            kwargs_mc2["comm_alg"] = "hierarchy"
+
         stage1_kwargs = {
             "scales": None,
             "quant_mode": quant_mode,
@@ -181,11 +181,6 @@ def get_dispatch_mc2_kwargs(
                 "tp_world_size": 1,
                 "tp_rank_id": 0,
             })
-        if self.need_expert_scale:
-            stage1_kwargs.update({
-                "expert_scales":
-                topk_weights.to(torch.float32),
-            })
         kwargs_mc2.update(stage1_kwargs)
 
         return kwargs_mc2
@@ -263,8 +258,12 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
             "shared_expert_rank_num": 0,
             "moe_expert_num": moe_expert_num,
             "global_bs": self.global_bs,
+            "expand_scales": expand_scales,
         }
 
+        if get_ascend_device_type() == AscendDeviceType.A2:
+            kwargs_mc2["comm_alg"] = "hierarchy"
+
         if self.with_quant:
             tp_recv_counts = torch.empty(1,
                                          dtype=torch.int32,
@@ -275,7 +274,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
             "group_ep": self.moe_all_to_all_group_name,
             "ep_world_size": self.ep_world_size,
             "ep_rank_id": self.ep_rank_id,
-            "expand_scales": expand_scales,
         }
 
         if self.enable_dispatch_v2:
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 97f8e2b66cc..51b87cfe37f 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -958,14 +958,6 @@ def calculate_dp_buffer_size() -> int:
     return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
 
 
-# Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1
-# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
-# significantly improve communication performance of MC2 ops dispatch/combine.
-def is_hierarchical_communication_enabled():
-    return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
-            and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
-
-
 def has_layer_idx(model_instance: torch.nn.Module) -> bool:
     if model_instance is None:
         return False
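
For reviewers, a minimal sketch of the behavioral change, outside the patch: hierarchical MC2 communication is no longer gated on the `HCCL_INTRA_PCIE_ENABLE`/`HCCL_INTRA_ROCE_ENABLE` environment variables; instead `comm_alg="hierarchy"` is requested per call whenever the device type is A2, and `expert_scales`/`expand_scales` are forwarded unconditionally. The helper name `build_dispatch_extra_kwargs` and the plain-string device type below are illustrative stand-ins, not code from the patch.

```python
import torch


def build_dispatch_extra_kwargs(device_type: str,
                                topk_weights: torch.Tensor) -> dict:
    """Illustrative stand-in for the kwargs construction in the patch."""
    extra = {
        # Expert scales are now passed unconditionally (previously only when
        # is_hierarchical_communication_enabled() returned True).
        "expert_scales": topk_weights.to(torch.float32),
    }
    if device_type == "A2":
        # Hierarchical communication is selected explicitly per call instead
        # of via HCCL_INTRA_PCIE_ENABLE / HCCL_INTRA_ROCE_ENABLE.
        extra["comm_alg"] = "hierarchy"
    return extra


if __name__ == "__main__":
    weights = torch.rand(4, 8)
    print(sorted(build_dispatch_extra_kwargs("A2", weights)))
    # ['comm_alg', 'expert_scales']
```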