from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

-import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
-from vllm_ascend.ascend_forward_context import (FusedMoEState,
-                                                _get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                       AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format):

@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
-    mock_setup_token_dispatchers = MagicMock()
-    mock_token_dispatcher_with_allgather = MagicMock()
-    mock_token_dispatcher_with_all2allv = MagicMock()
-    mock_token_dispatcher_with_mc2 = MagicMock()
-
-    mock_dispatch_result_allgather = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([8, 16], dtype=torch.int64),
-        "group_list_type": 0,
-    }
-    mock_combine_result_allgather = torch.randn(16, 2)
-
-    mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
-    mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
-
-    mock_dispatch_result_all2allv = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
-        "group_list_type": 1,
-        "dynamic_scale": None,
-    }
-    mock_combine_result_all2allv = torch.randn(16, 2)
-    mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
-    mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
-
-    mock_dispatch_result_mc2 = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
-        "group_list_type": 1,
-        "dynamic_scale": None,
-        "assist_info_for_combine": torch.randn(16, 2),
-        "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
-    }
-    mock_combine_result_mc2 = torch.randn(16, 2)
-    mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
-    mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
+    mock_moe_comm_method = MagicMock()

-    captured_dispatchers = {}
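+    # prepare() is stubbed to hand back hidden_states and router_logits untouched.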
+    def mock_prepare(hidden_states, router_logits, **kwargs):
+        return hidden_states, router_logits

-    def capture_register(dispatcher_instance):
-        key = dispatcher_instance.__class__.__name__
-        captured_dispatchers[key] = dispatcher_instance
-        if key == 'TokenDispatcherWithAllGather':
-            captured_dispatchers[key] = mock_token_dispatcher_with_allgather
-        elif key == 'TokenDispatcherWithAll2AllV':
-            captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
-        elif key == 'TokenDispatcherWithMC2':
-            captured_dispatchers[key] = mock_token_dispatcher_with_mc2
+    mock_moe_comm_method.prepare.side_effect = mock_prepare

-    mock_register_token_dispatcher_patcher = patch(
-        'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
-        side_effect=capture_register)
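+    # fused_experts() always returns a fixed (16, 2) tensor, so result shapes are deterministic.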
+    mock_fused_experts_result = torch.randn(16, 2)
+    mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result

-    mock_get_token_dispatcher_patcher = patch(
-        'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
-        side_effect=lambda name: captured_dispatchers.get(name))
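+    # finalize() is stubbed to return hidden_states unchanged.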
+    def mock_finalize(hidden_states, **kwargs):
+        return hidden_states

-    default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
+    mock_moe_comm_method.finalize.side_effect = mock_finalize

    mock_forward_context_obj = MagicMock(
-        fused_moe_state=FusedMoEState.AllGather,
-        token_dispatcher=default_mock_token_dispatcher,
+        moe_comm_method=mock_moe_comm_method,
+        moe_comm_method_name="mc2commimpl",
        max_tokens_across_dp=10,
        dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
        mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -131,14 +84,12 @@ def capture_register(dispatcher_instance):
    with patch('torch.distributed.get_rank', return_value=0), \
        patch('torch.distributed.get_world_size', return_value=4), \
        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('torch.distributed.all_gather'), \
-        patch('torch.distributed.all_to_all_single'), \
-        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
              return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
@@ -150,29 +101,29 @@ def capture_register(dispatcher_instance):
              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
        patch('vllm_ascend.ops.fused_moe.get_forward_context',
              return_value=mock_forward_context_obj), \
+        patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
+              return_value=mock_forward_context_obj), \
        patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
              return_value=MagicMock(
                  parallel_config=MagicMock(tensor_parallel_size=2),
                  scheduler_config=MagicMock(max_num_seqs=4),
                  model_config=MagicMock(max_model_len=2048)
              )), \
        patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
        patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
-              return_value=mock_forward_context_obj):
+              return_value=mock_forward_context_obj), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
+              return_value=None):

        yield {
            'mock_forward_context_obj': mock_forward_context_obj,
-            'mock_token_dispatcher_with_allgather':
-            mock_token_dispatcher_with_allgather,
-            'mock_token_dispatcher_with_all2allv':
-            mock_token_dispatcher_with_all2allv,
-            'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
+            'mock_moe_comm_method': mock_moe_comm_method,
        }

-    mock_register_token_dispatcher_patcher.stop()
-    mock_get_token_dispatcher_patcher.stop()
-

@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
@@ -338,9 +289,7 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
        moe.moe_parallel_config.ep_size = 1

        moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
-                                                         dtype=torch.bool),
-                                    padded_num_tokens=num_tokens)
+        forward_context = mock_dist_env['mock_forward_context_obj']
        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context):
            output = moe.forward(inputs,
@@ -394,25 +343,10 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
                             [[256, 4], [128, 1], [128, 1], [128, 4]])
    def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                      mock_moe_env, others_param):
-
        global_num_experts, ep_size = others_param
        is_prefill = False
-        is_deepseek_v3_r1 = global_num_experts == 256
-
-        if ep_size == 1:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_allgather']
-        elif ep_size < 16:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_all2allv']
-        else:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_mc2']

-        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
-            ep_size, is_prefill, is_deepseek_v3_r1),
-                                    with_quant=False,
-                                    token_dispatcher=selected_token_dispatcher)
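+        # The fixture's shared forward context already carries the mocked comm method.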
+        forward_context = mock_dist_env['mock_forward_context_obj']

        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context):
@@ -438,35 +372,22 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                  global_num_experts=global_num_experts,
                                  is_prefill=is_prefill)

-        expected_shape = (16, 2)
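+        # apply() should route exactly one fused_experts call through the comm method.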
+        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
+        mock_moe_comm_method.fused_experts.assert_called_once()

+        expected_shape = (16, 2)
        assert result.shape == expected_shape

    @pytest.mark.parametrize("others_param", [16, 1, 4])
    def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                   mock_moe_env, others_param):
-
        ep_size = others_param
        is_prefill = False

-        if ep_size == 1:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_allgather']
-        elif ep_size < 16:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_all2allv']
-        else:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_mc2']
-
-        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
-            ep_size, is_prefill, True),
-                                    with_quant=False,
-                                    token_dispatcher=selected_token_dispatcher)
+        forward_context = mock_dist_env['mock_forward_context_obj']

        with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
            patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
-
            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
@@ -493,8 +414,10 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                  expert_map=expert_map,
                                  is_prefill=is_prefill)

-        expected_shape = (16, 2)
+        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
+        mock_moe_comm_method.fused_experts.assert_called_once()

+        expected_shape = (16, 2)
        assert result.shape == expected_shape


@@ -574,7 +497,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                     mock_get_forward_context):

        mock_forward_context = MagicMock()
-        mock_forward_context.fused_moe_state = FusedMoEState.MC2
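+        # The MLP path now keys off moe_comm_method_name rather than FusedMoEState.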
+        mock_forward_context.moe_comm_method_name = "mc2commimpl"
        mock_get_forward_context.return_value = mock_forward_context

        mock_is_310p.return_value = False
@@ -618,8 +541,6 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                   with_quant=True)

        mock_get_forward_context.assert_called()
-        self.assertEqual(mock_forward_context.fused_moe_state,
-                         FusedMoEState.MC2)

        mock_npu_dynamic_quant.assert_called()