22 changes: 12 additions & 10 deletions vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -32,7 +32,8 @@
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe.comm_utils import (
     async_all_to_all, gather_from_sequence_parallel_region)
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
+                               is_hierarchical_communication_enabled)


 @dataclass
@@ -116,6 +117,10 @@ def __init__(self, **kwargs):
         self.need_extra_args = (
             get_ascend_device_type() == AscendDeviceType.A3)

+        # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
+        # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
+        # improve communication performance.
+        self.need_expert_scale = is_hierarchical_communication_enabled()
         self.with_quant = False

         # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
@@ -153,7 +158,6 @@ def get_dispatch_mc2_kwargs(
         else:
             quant_mode = 0
         moe_expert_num = len(expert_map)
-
         kwargs_mc2 = {
             "x": hidden_states,
             "expert_ids": topk_ids,
@@ -162,12 +166,8 @@
             "moe_expert_num": moe_expert_num,
             "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
-            "expert_scales": topk_weights.to(torch.float32),
         }

-        if get_ascend_device_type() == AscendDeviceType.A2:
-            kwargs_mc2["comm_alg"] = "hierarchy"
-
         stage1_kwargs = {
             "scales": None,
             "quant_mode": quant_mode,
@@ -181,6 +181,11 @@
                 "tp_world_size": 1,
                 "tp_rank_id": 0,
             })
+        if self.need_expert_scale:
+            stage1_kwargs.update({
+                "expert_scales":
+                topk_weights.to(torch.float32),
+            })

         kwargs_mc2.update(stage1_kwargs)
         return kwargs_mc2
@@ -258,12 +263,8 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
             "shared_expert_rank_num": 0,
             "moe_expert_num": moe_expert_num,
             "global_bs": self.global_bs,
-            "expand_scales": expand_scales,
         }

-        if get_ascend_device_type() == AscendDeviceType.A2:
-            kwargs_mc2["comm_alg"] = "hierarchy"
-
         if self.with_quant:
             tp_recv_counts = torch.empty(1,
                                          dtype=torch.int32,
@@ -274,6 +275,7 @@
                 "group_ep": self.moe_all_to_all_group_name,
                 "ep_world_size": self.ep_world_size,
                 "ep_rank_id": self.ep_rank_id,
+                "expand_scales": expand_scales,
             }

             if self.enable_dispatch_v2:
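Taken together, the token_dispatcher.py changes replace the previously unconditional arguments (comm_alg="hierarchy", always-passed expert_scales/expand_scales) with a switch driven by the HCCL environment variables. The following is a minimal, illustrative sketch of the resulting dispatch-side control flow; the class and method names here are simplified stand-ins for the sketch, not the actual vllm-ascend implementation.

import os

import torch


def is_hierarchical_communication_enabled() -> bool:
    return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
            and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")


class DispatcherSketch:
    """Simplified stand-in for the MC2 token dispatcher (illustration only)."""

    def __init__(self, global_bs: int = 0):
        self.global_bs = global_bs
        # Mirrors the new attribute added in __init__ above.
        self.need_expert_scale = is_hierarchical_communication_enabled()

    def get_dispatch_kwargs(self, hidden_states, topk_ids, topk_weights,
                            expert_map):
        kwargs = {
            "x": hidden_states,
            "expert_ids": topk_ids,
            "moe_expert_num": len(expert_map),
            "global_bs": self.global_bs,
            "expert_token_nums_type": 0,
        }
        # expert_scales is now passed only when the hierarchical HCCL
        # switches are set, instead of unconditionally.
        if self.need_expert_scale:
            kwargs["expert_scales"] = topk_weights.to(torch.float32)
        return kwargs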
8 changes: 8 additions & 0 deletions vllm_ascend/utils.py
@@ -983,6 +983,14 @@ def calculate_dp_buffer_size() -> int:
     return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)


+# Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1
+# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
+# significantly improve communication performance of MC2 ops dispatch/combine.
+def is_hierarchical_communication_enabled():
+    return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
+            and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
Comment on lines +989 to +991
Contributor

high

The comments in both this file and token_dispatcher.py strongly suggest that this hierarchical communication feature is specifically for A2 devices. If the environment variables are set on another device type, the feature could be enabled accidentally and lead to unexpected behavior or performance regressions, so it is safer to add an explicit device type check here.

Suggested change

-def is_hierarchical_communication_enabled():
-    return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
-            and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
+def is_hierarchical_communication_enabled():
+    return (get_ascend_device_type() == AscendDeviceType.A2 and
+            os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
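As a usage illustration only (not part of the PR), here is a self-contained sketch of how the reviewer's device-gated variant would behave. AscendDeviceType and get_ascend_device_type below are placeholder stand-ins for the real vllm_ascend.utils helpers.

import os
from enum import Enum


class AscendDeviceType(Enum):
    # Placeholder stand-in for vllm_ascend.utils.AscendDeviceType.
    A2 = "A2"
    A3 = "A3"


def get_ascend_device_type() -> AscendDeviceType:
    # Placeholder; the real helper queries the Ascend runtime.
    return AscendDeviceType.A2


def is_hierarchical_communication_enabled() -> bool:
    # Reviewer's suggestion: require an A2 device in addition to the
    # HCCL environment switches.
    return (get_ascend_device_type() == AscendDeviceType.A2
            and os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
            and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")


if __name__ == "__main__":
    # Example: enable the hierarchical MC2 path via environment variables.
    os.environ["HCCL_INTRA_PCIE_ENABLE"] = "1"
    os.environ["HCCL_INTRA_ROCE_ENABLE"] = "0"
    print(is_hierarchical_communication_enabled())  # True in this sketch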



 def has_layer_idx(model_instance: torch.nn.Module) -> bool:
     if model_instance is None:
         return False