68 changes: 41 additions & 27 deletions tests/e2e/multicard/moe/test_moe_comm.py
@@ -18,29 +18,30 @@

import pytest
import torch
from transformers import PretrainedConfig
from vllm import forward_context

from vllm_ascend.distributed import moe_comm_method
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
NativeAllGatherCommImpl)
from vllm.model_executor.layers.fused_moe.config import ( # isort: skip
FusedMoEConfig, FusedMoEParallelConfig)

from vllm_ascend.distributed.moe_comm_method import ( # isort: skip
AllGatherCommImpl, NativeAllGatherCommImpl)


@pytest.mark.parametrize("num_tokens", [16, 128])
@pytest.mark.parametrize("hidden_size", [64, 128])
@pytest.mark.parametrize("global_num_experts", [8, 16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("top_k_num", [2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("ep_rank", [0, 1])
def test_all_gather_comm_impl(
num_tokens,
hidden_size,
global_num_experts,
num_local_experts,
top_k_num,
dtype,
num_local_experts,
ep_rank,
mocker,
):
"""
Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
@@ -56,23 +57,37 @@ def test_all_gather_comm_impl(
"num_local_experts cannot be greater than global_num_experts")

device = torch.device("npu")
hf_config = PretrainedConfig(
num_experts_per_tok=top_k_num,

# mock get_tensor_model_parallel_rank to return ep_rank
mocker.patch(
"vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
return_value=ep_rank,
)

# make moe config
parallel_config = SimpleNamespace(
enable_expert_parallel=num_local_experts < global_num_experts)
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=max(2, global_num_experts // num_local_experts),
dp_size_=1,
vllm_parallel_config=parallel_config,
)

moe_config = FusedMoEConfig(
num_experts=global_num_experts,
experts_per_token=top_k_num,
hidden_dim=hidden_size,
num_local_experts=num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=dtype,
quant_config=None, # No quantization in this test
max_num_tokens=num_tokens,
)

# Instantiate implementations
native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)

all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)
native_impl = NativeAllGatherCommImpl(moe_config)

# TODO: Find out if this is the correct way to mock the forward context and ep group
# Mock get_forward_context to return an object with moe_comm_method
forward_context._forward_context = SimpleNamespace(
moe_comm_method=all_gather_impl)
# Mock get_ep_group to return a fake group with the specified ep_rank
fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
moe_comm_method.get_ep_group = lambda: fake_ep_group
all_gather_impl = AllGatherCommImpl(moe_config)

# --- Input Data ---
hidden_states = torch.randn(num_tokens,
@@ -103,27 +118,26 @@
native_permuted_hidden,
native_expert_tokens,
_,
) = native_impl._pre_process(hidden_states, topk_ids, topk_weights,
expert_map, num_experts)
) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
num_experts)
# Simulate MLP output
native_mlp_output = torch.randn_like(native_permuted_hidden)
native_impl._post_process(native_mlp_output, native_hidden_states_out)
native_impl.unpermute(native_mlp_output, native_hidden_states_out)

# --- Run AllGather Implementation ---
all_gather_hidden_states_out = hidden_states.clone()
(
all_gather_permuted_hidden,
all_gather_expert_tokens,
_,
) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
topk_weights, expert_map,
num_experts)
) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
expert_map, num_experts)

# Use the same simulated MLP output for a fair comparison
all_gather_mlp_output = native_mlp_output.clone()

torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
all_gather_hidden_states_out)
all_gather_impl.unpermute(all_gather_mlp_output,
all_gather_hidden_states_out)

# --- Assertions ---
# Define tolerance based on dtype
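
For reference, the renamed permute/unpermute hooks exercised above follow the standard MoE token-permutation pattern: each token is replicated top_k times, the copies are sorted so that rows routed to the same expert are contiguous, the expert MLPs run on the sorted rows, and the results are scattered back and combined with the routing weights. The sketch below is a minimal, framework-free illustration of that contract in plain PyTorch; it is not the actual NativeAllGatherCommImpl (which also handles expert_map remapping, local-expert masking, and the NPU all-gather path), and the helper names are invented for this example.

    import torch

    def reference_permute(hidden_states, topk_ids, num_experts):
        # hidden_states: [num_tokens, hidden]; topk_ids: [num_tokens, top_k]
        num_tokens, top_k = topk_ids.shape
        flat_expert_ids = topk_ids.reshape(-1)
        # Stable sort so all (token, expert) pairs for one expert are contiguous.
        sort_order = torch.argsort(flat_expert_ids, stable=True)
        token_indices = sort_order // top_k  # original token for each permuted row
        permuted_hidden = hidden_states[token_indices]
        # Number of permuted rows assigned to each expert.
        expert_tokens = torch.bincount(flat_expert_ids, minlength=num_experts)
        return permuted_hidden, expert_tokens, sort_order

    def reference_unpermute(mlp_output, sort_order, topk_weights, out):
        # Invert the permutation, then weight and reduce the top_k rows per token.
        num_tokens, top_k = topk_weights.shape
        restored = torch.empty_like(mlp_output)
        restored[sort_order] = mlp_output
        restored = restored.view(num_tokens, top_k, -1)
        out.copy_((restored * topk_weights.unsqueeze(-1)).sum(dim=1).to(out.dtype))
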
68 changes: 1 addition & 67 deletions tests/ut/distributed/test_communicator.py
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, patch

import torch
import torch.distributed as dist
@@ -87,69 +87,3 @@ def patched_all_to_all(output_tensor_list,
output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)

assert output.tolist() == [[10, 20], [50, 60]]

@patch("vllm.config.get_current_vllm_config", return_value=None)
@patch("torch.npu.current_device", return_value=MagicMock())
@patch("torch.npu.set_device", return_value=MagicMock())
@patch("torch.distributed.get_process_group_ranks",
return_value={
0: 0,
1: 1
})
@patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="hccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.npu.device")
def test_dispatch(self, *_):
comm = NPUCommunicator(cpu_group=dist.group.WORLD)
comm.all2all_manager = Mock()
hidden_states = torch.randn(2, 4, 8)
router_logits = torch.randn(2, 4, 2)

mock_dispatch_result = (torch.randn(2, 4, 8), torch.randn(2, 4, 2))
comm.all2all_manager.dispatch.return_value = mock_dispatch_result

result_hidden, result_logits = comm.dispatch(hidden_states,
router_logits)

assert torch.allclose(result_hidden, mock_dispatch_result[0])
assert torch.allclose(result_logits, mock_dispatch_result[1])

comm.all2all_manager.dispatch.assert_called_once_with(
hidden_states, router_logits)

@patch("vllm.config.get_current_vllm_config", return_value=None)
@patch("torch.npu.current_device", return_value=MagicMock())
@patch("torch.npu.set_device", return_value=MagicMock())
@patch("torch.distributed.get_process_group_ranks",
return_value={
0: 0,
1: 1
})
@patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="hccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.npu.device")
def test_combine(self, *_):
comm = NPUCommunicator(cpu_group=dist.group.WORLD)
comm.all2all_manager = Mock()
hidden_states = torch.randn(2, 4, 8)

mock_combine_result = torch.randn(2, 4, 8)
comm.all2all_manager.combine.return_value = mock_combine_result

result = comm.combine(hidden_states)

assert torch.allclose(result, mock_combine_result)

comm.all2all_manager.combine.assert_called_once_with(hidden_states)
4 changes: 2 additions & 2 deletions tests/ut/test_utils.py
@@ -289,13 +289,13 @@ def test_register_ascend_customop(self, mock_ascend_rmsnorm,
# ascend custom op is not registered
utils.register_ascend_customop()
# should call register_oot three
self.assertEqual(mock_customop.register_oot.call_count, 8)
self.assertEqual(mock_customop.register_oot.call_count, 9)
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)

# ascend custom op is already registered
utils.register_ascend_customop()
# should not register_oot again, thus only called three in this ut
self.assertEqual(mock_customop.register_oot.call_count, 8)
self.assertEqual(mock_customop.register_oot.call_count, 9)


class TestProfileExecuteDuration(TestBase):
5 changes: 2 additions & 3 deletions vllm_ascend/ascend_forward_context.py
@@ -11,7 +11,6 @@
set_forward_context)

import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.moe_comm_method import MoECommMethod


class FusedMoEState(Enum):
@@ -57,7 +56,7 @@ def set_ascend_forward_context(
with_prefill: bool = True,
in_profile_run: bool = False,
reserved_mc2_mask: Optional[torch.Tensor] = None,
moe_comm_method: Optional[MoECommMethod] = None,
moe_comm_method: str = "",
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
@@ -75,7 +74,7 @@
batch_descriptor=batch_descriptor,
):
forward_context = get_forward_context()
forward_context.moe_comm_method = moe_comm_method
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
forward_context.with_prefill = with_prefill
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
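
With this change the forward context carries only a lower-case name string ("<method>" + "commimpl") instead of a live comm-method object. The sketch below is a hypothetical illustration of how a caller could resolve that string back to one of the implementations, each of which is now constructed from a FusedMoEConfig; the lookup table and the resolve_moe_comm_impl helper are assumptions made for this example, not the actual vllm-ascend resolution code.

    from vllm.forward_context import get_forward_context

    from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                          NativeAllGatherCommImpl)

    # Hypothetical mapping from the stored name to an implementation class.
    _NAME_TO_IMPL = {
        "allgathercommimpl": AllGatherCommImpl,
        "nativeallgathercommimpl": NativeAllGatherCommImpl,
    }

    def resolve_moe_comm_impl(moe_config):
        # set_ascend_forward_context(..., moe_comm_method="allgather") stores
        # "allgathercommimpl" on the context.
        name = get_forward_context().moe_comm_method_name
        return _NAME_TO_IMPL[name](moe_config)
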
21 changes: 0 additions & 21 deletions vllm_ascend/distributed/communicator.py
@@ -20,7 +20,6 @@
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
from vllm.utils import logger


class NPUCommunicator(DeviceCommunicatorBase):
@@ -35,12 +34,6 @@ def __init__(self,
# init device according to rank
self.device = torch.npu.current_device()

if self.use_all2all:
from vllm.distributed.device_communicators.all2all import \
NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")

def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
@@ -80,17 +73,3 @@ def all_to_all(self,
dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
return output_tensor

# TODO: Add ut for dispatch and combine
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states
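
With the naive all2all manager and the dispatch/combine wrappers gone, NPUCommunicator keeps only the generic all_to_all helper. A minimal usage sketch, assuming an already-initialized HCCL process group and NPU device (shapes are arbitrary), based on the signature exercised in tests/ut/distributed/test_communicator.py:

    import torch
    import torch.distributed as dist

    from vllm_ascend.distributed.communicator import NPUCommunicator

    # Assumes dist.init_process_group(backend="hccl", ...) has already run.
    comm = NPUCommunicator(cpu_group=dist.group.WORLD)
    world_size = dist.get_world_size()

    # One contiguous chunk per rank along dim 0: each rank scatters its chunks
    # and gathers the chunks addressed to it along the same dimension.
    local_input = torch.randn(world_size * 4, 8, device="npu")
    output = comm.all_to_all(local_input, scatter_dim=0, gather_dim=0)
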