Commit 9694820

test: Refactor MoE communication test
The test now uses `FusedMoEConfig` for configuration instead of a generic `PretrainedConfig`. It also calls the `permute` and `unpermute` methods on the communication implementation instance rather than calling the `torch.ops` functions directly.

Signed-off-by: Yizhou Liu <[email protected]>
1 parent 91d741e commit 9694820
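
For quick orientation, the sketch below condenses the pattern this commit moves to. It is not part of the commit: the literal values are arbitrary, only names that appear in the diff below are used, and exact constructor signatures may differ across vllm / vllm-ascend versions.

# Illustrative sketch only; not taken from the commit. Values are arbitrary.
from types import SimpleNamespace

import torch
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig)
from vllm_ascend.distributed.moe_comm_method import AllGatherCommImpl

# In the test below, get_tensor_model_parallel_rank is mocked before this
# call, so no initialized distributed group is required at this point.
parallel_config = SimpleNamespace(enable_expert_parallel=True)
moe_parallel_config = FusedMoEParallelConfig.make(
    tp_size_=2,
    dp_size_=1,
    vllm_parallel_config=parallel_config,
)

# One FusedMoEConfig now replaces the old (device, dtype, PretrainedConfig)
# constructor arguments of the comm implementations.
moe_config = FusedMoEConfig(
    num_experts=8,
    experts_per_token=2,
    hidden_dim=64,
    num_local_experts=4,
    moe_parallel_config=moe_parallel_config,
    in_dtype=torch.bfloat16,
    quant_config=None,
    max_num_tokens=16,
)

impl = AllGatherCommImpl(moe_config)
# The test then calls impl.permute(...) and impl.unpermute(...) directly,
# replacing the former torch.ops.vllm.moe_comm_pre_process /
# moe_comm_post_process custom-op calls.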

File tree

2 files changed: +36 -25 lines changed

  tests/e2e/multicard/moe/test_moe_comm.py
  vllm_ascend/distributed/moe_comm_method.py

tests/e2e/multicard/moe/test_moe_comm.py

Lines changed: 35 additions & 22 deletions
@@ -18,29 +18,29 @@
 
 import pytest
 import torch
-from transformers import PretrainedConfig
-from vllm import forward_context
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEParallelConfig)
 
-from vllm_ascend.distributed import moe_comm_method
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                      NativeAllGatherCommImpl)
 
 
 @pytest.mark.parametrize("num_tokens", [16, 128])
 @pytest.mark.parametrize("hidden_size", [64, 128])
 @pytest.mark.parametrize("global_num_experts", [8, 16])
+@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("top_k_num", [2, 4])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("ep_rank", [0, 1])
 def test_all_gather_comm_impl(
     num_tokens,
     hidden_size,
     global_num_experts,
+    num_local_experts,
     top_k_num,
     dtype,
-    num_local_experts,
     ep_rank,
+    mocker,
 ):
     """
     Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
@@ -56,23 +56,37 @@ def test_all_gather_comm_impl(
             "num_local_experts cannot be greater than global_num_experts")
 
     device = torch.device("npu")
-    hf_config = PretrainedConfig(
-        num_experts_per_tok=top_k_num,
+
+    # mock get_tensor_model_parallel_rank to return ep_rank
+    mocker.patch(
+        "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
+        return_value=ep_rank,
+    )
+
+    # make moe config
+    parallel_config = SimpleNamespace(
+        enable_expert_parallel=num_local_experts < global_num_experts)
+    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
+        tp_size_=max(2, global_num_experts // num_local_experts),
+        dp_size_=1,
+        vllm_parallel_config=parallel_config,
+    )
+
+    moe_config = FusedMoEConfig(
         num_experts=global_num_experts,
+        experts_per_token=top_k_num,
+        hidden_dim=hidden_size,
+        num_local_experts=num_local_experts,
+        moe_parallel_config=moe_parallel_config,
+        in_dtype=dtype,
+        quant_config=None,  # No quantization in this test
+        max_num_tokens=num_tokens,
     )
 
     # Instantiate implementations
-    native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)
-
-    all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)
+    native_impl = NativeAllGatherCommImpl(moe_config)
 
-    # TODO: Find out if this is the correct way to mock the forward context and ep group
-    # Mock get_forward_context to return an object with moe_comm_method
-    forward_context._forward_context = SimpleNamespace(
-        moe_comm_method=all_gather_impl)
-    # Mock get_ep_group to return a fake group with the specified ep_rank
-    fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
-    moe_comm_method.get_ep_group = lambda: fake_ep_group
+    all_gather_impl = AllGatherCommImpl(moe_config)
 
     # --- Input Data ---
     hidden_states = torch.randn(num_tokens,
@@ -115,15 +129,14 @@ def test_all_gather_comm_impl(
         all_gather_permuted_hidden,
         all_gather_expert_tokens,
         _,
-    ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
-                                            topk_weights, expert_map,
-                                            num_experts)
+    ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
+                                expert_map, num_experts)
 
     # Use the same simulated MLP output for a fair comparison
     all_gather_mlp_output = native_mlp_output.clone()
 
-    torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
-                                         all_gather_hidden_states_out)
+    all_gather_impl.unpermute(all_gather_mlp_output,
+                              all_gather_hidden_states_out)
 
     # --- Assertions ---
     # Define tolerance based on dtype

vllm_ascend/distributed/moe_comm_method.py

Lines changed: 1 addition & 3 deletions
@@ -21,9 +21,7 @@
 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
 
-    moe_config: FusedMoEConfig = None
-
-    def __init__(self, moe_config: Optional[FusedMoEConfig]):
+    def __init__(self, moe_config: FusedMoEConfig):
         self.moe_config = moe_config
 
     @abstractmethod
