
 import pytest
 import torch
-from transformers import PretrainedConfig
-from vllm import forward_context
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEParallelConfig)

-from vllm_ascend.distributed import moe_comm_method
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                      NativeAllGatherCommImpl)


 @pytest.mark.parametrize("num_tokens", [16, 128])
 @pytest.mark.parametrize("hidden_size", [64, 128])
 @pytest.mark.parametrize("global_num_experts", [8, 16])
+@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("top_k_num", [2, 4])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("num_local_experts", [4, 8])
 @pytest.mark.parametrize("ep_rank", [0, 1])
 def test_all_gather_comm_impl(
     num_tokens,
     hidden_size,
     global_num_experts,
+    num_local_experts,
     top_k_num,
     dtype,
-    num_local_experts,
     ep_rank,
+    mocker,
 ):
     """
     Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
@@ -56,23 +56,37 @@ def test_all_gather_comm_impl(
             "num_local_experts cannot be greater than global_num_experts")

     device = torch.device("npu")
-    hf_config = PretrainedConfig(
-        num_experts_per_tok=top_k_num,
+
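+    # The parallel config built below derives its rank from this patched
+    # lookup, so the test can exercise different EP ranks.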
+    # mock get_tensor_model_parallel_rank to return ep_rank
+    mocker.patch(
+        "vllm.model_executor.layers.fused_moe.config.get_tensor_model_parallel_rank",
+        return_value=ep_rank,
+    )
+
+    # make moe config
+    parallel_config = SimpleNamespace(
+        enable_expert_parallel=num_local_experts < global_num_experts)
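+    # tp_size is the expert sharding factor, floored at 2 so that ep_rank=1
+    # is always a valid rank in the parametrized cases.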
+    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
+        tp_size_=max(2, global_num_experts // num_local_experts),
+        dp_size_=1,
+        vllm_parallel_config=parallel_config,
+    )
+
+    moe_config = FusedMoEConfig(
         num_experts=global_num_experts,
+        experts_per_token=top_k_num,
+        hidden_dim=hidden_size,
+        num_local_experts=num_local_experts,
+        moe_parallel_config=moe_parallel_config,
+        in_dtype=dtype,
+        quant_config=None,  # No quantization in this test
+        max_num_tokens=num_tokens,
     )

     # Instantiate implementations
-    native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)
-
-    all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)
+    native_impl = NativeAllGatherCommImpl(moe_config)

-    # TODO: Find out if this is the correct way to mock the forward context and ep group
-    # Mock get_forward_context to return an object with moe_comm_method
-    forward_context._forward_context = SimpleNamespace(
-        moe_comm_method=all_gather_impl)
-    # Mock get_ep_group to return a fake group with the specified ep_rank
-    fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
-    moe_comm_method.get_ep_group = lambda: fake_ep_group
+    all_gather_impl = AllGatherCommImpl(moe_config)

     # --- Input Data ---
     hidden_states = torch.randn(num_tokens,
@@ -115,15 +129,14 @@ def test_all_gather_comm_impl(
         all_gather_permuted_hidden,
         all_gather_expert_tokens,
         _,
-    ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
-                                            topk_weights, expert_map,
-                                            num_experts)
+    ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
+                                expert_map, num_experts)

     # Use the same simulated MLP output for a fair comparison
     all_gather_mlp_output = native_mlp_output.clone()

-    torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
-                                         all_gather_hidden_states_out)
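+    # Un-permute the simulated MLP output into all_gather_hidden_states_out.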
+    all_gather_impl.unpermute(all_gather_mlp_output,
+                              all_gather_hidden_states_out)

     # --- Assertions ---
     # Define tolerance based on dtype