
Commit 461834a

Adapt fused_moe to moe_comm_method in eager mode
Co-Authored-By: weijinqian0 <[email protected]>
Signed-off-by: Pr0Wh1teGivee <[email protected]>
1 parent bd3dede commit 461834a

13 files changed: +387 -552 lines
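Note: as the diffs below show, the updated tests replace the per-dispatcher mocks with a single forward-context moe_comm_method stub exposing prepare, fused_experts, and finalize. A minimal self-contained sketch of the call flow those mocks imply is given here; the helper name run_moe_with_comm_method and the tensor shapes are illustrative assumptions, not vllm-ascend APIs.

    # Sketch only: mirrors the mock_dist_env fixture below, assuming the
    # prepare -> fused_experts -> finalize interface implied by the mocks.
    import torch
    from unittest.mock import MagicMock

    def run_moe_with_comm_method(comm_method, hidden_states, router_logits):
        # Pre-communication step (e.g. padding / dispatch bookkeeping).
        hidden_states, router_logits = comm_method.prepare(hidden_states, router_logits)
        # Expert computation routed through the selected comm implementation.
        out = comm_method.fused_experts(hidden_states=hidden_states,
                                        router_logits=router_logits)
        # Post-communication step (e.g. combine / unpad).
        return comm_method.finalize(out)

    comm = MagicMock()
    comm.prepare.side_effect = lambda h, r, **kw: (h, r)
    comm.fused_experts.return_value = torch.randn(16, 2)
    comm.finalize.side_effect = lambda h, **kw: h

    result = run_moe_with_comm_method(comm, torch.randn(16, 2), torch.randn(16, 8))
    assert result.shape == (16, 2)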

tests/ut/ops/test_ascend_forwad_context.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

tests/ut/ops/test_fused_ops.py

Lines changed: 32 additions & 111 deletions
@@ -22,10 +22,7 @@
 from pytest_mock import MockerFixture
 from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
 
-import vllm_ascend.ops.moe.token_dispatcher as token_dispatcher_module
 from tests.ut.base import TestBase
-from vllm_ascend.ascend_forward_context import (FusedMoEState,
-                                                _get_fused_moe_state)
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.moe.experts_selector import select_experts
@@ -60,68 +57,24 @@ def mock_npu_format_cast(weight_data, format):
 
 @pytest.fixture
 def mock_dist_env(mocker: MockerFixture):
-    mock_setup_token_dispatchers = MagicMock()
-    mock_token_dispatcher_with_allgather = MagicMock()
-    mock_token_dispatcher_with_all2allv = MagicMock()
-    mock_token_dispatcher_with_mc2 = MagicMock()
-
-    mock_dispatch_result_allgather = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([8, 16], dtype=torch.int64),
-        "group_list_type": 0,
-    }
-    mock_combine_result_allgather = torch.randn(16, 2)
-
-    mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
-    mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
-
-    mock_dispatch_result_all2allv = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
-        "group_list_type": 1,
-        "dynamic_scale": None,
-    }
-    mock_combine_result_all2allv = torch.randn(16, 2)
-    mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
-    mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
-
-    mock_dispatch_result_mc2 = {
-        "hidden_states": torch.randn(16, 2),
-        "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
-        "group_list_type": 1,
-        "dynamic_scale": None,
-        "assist_info_for_combine": torch.randn(16, 2),
-        "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
-    }
-    mock_combine_result_mc2 = torch.randn(16, 2)
-    mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
-    mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
+    mock_moe_comm_method = MagicMock()
 
-    captured_dispatchers = {}
+    def mock_prepare(hidden_states, router_logits, **kwargs):
+        return hidden_states, router_logits
 
-    def capture_register(dispatcher_instance):
-        key = dispatcher_instance.__class__.__name__
-        captured_dispatchers[key] = dispatcher_instance
-        if key == 'TokenDispatcherWithAllGather':
-            captured_dispatchers[key] = mock_token_dispatcher_with_allgather
-        elif key == 'TokenDispatcherWithAll2AllV':
-            captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
-        elif key == 'TokenDispatcherWithMC2':
-            captured_dispatchers[key] = mock_token_dispatcher_with_mc2
+    mock_moe_comm_method.prepare.side_effect = mock_prepare
 
-    mock_register_token_dispatcher_patcher = patch(
-        'vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher',
-        side_effect=capture_register)
+    mock_fused_experts_result = torch.randn(16, 2)
+    mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
 
-    mock_get_token_dispatcher_patcher = patch(
-        'vllm_ascend.ops.moe.token_dispatcher.get_token_dispatcher',
-        side_effect=lambda name: captured_dispatchers.get(name))
+    def mock_finalize(hidden_states, **kwargs):
+        return hidden_states
 
-    default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
+    mock_moe_comm_method.finalize.side_effect = mock_finalize
 
     mock_forward_context_obj = MagicMock(
-        fused_moe_state=FusedMoEState.AllGather,
-        token_dispatcher=default_mock_token_dispatcher,
+        moe_comm_method=mock_moe_comm_method,
+        moe_comm_method_name="mc2commimpl",
         max_tokens_across_dp=10,
        dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
        mc2_mask=torch.zeros(16, dtype=torch.bool),
@@ -131,14 +84,12 @@ def capture_register(dispatcher_instance):
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
         patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('torch.distributed.all_gather'), \
-        patch('torch.distributed.all_to_all_single'), \
-        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
         patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_ascend_config',
@@ -150,29 +101,29 @@ def capture_register(dispatcher_instance):
               return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
         patch('vllm_ascend.ops.fused_moe.get_forward_context',
               return_value=mock_forward_context_obj), \
+        patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
+              return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
               return_value=MagicMock(
                   parallel_config=MagicMock(tensor_parallel_size=2),
                   scheduler_config=MagicMock(max_num_seqs=4),
                   model_config=MagicMock(max_model_len=2048)
              )), \
         patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
         patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
-              return_value=mock_forward_context_obj):
+              return_value=mock_forward_context_obj), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
+              return_value=None):
 
         yield {
             'mock_forward_context_obj': mock_forward_context_obj,
-            'mock_token_dispatcher_with_allgather':
-            mock_token_dispatcher_with_allgather,
-            'mock_token_dispatcher_with_all2allv':
-            mock_token_dispatcher_with_all2allv,
-            'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
+            'mock_moe_comm_method': mock_moe_comm_method,
         }
 
-        mock_register_token_dispatcher_patcher.stop()
-        mock_get_token_dispatcher_patcher.stop()
-
 
 @pytest.fixture
 def mock_moe_env(mocker: MockerFixture):
@@ -338,9 +289,7 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
         moe.moe_parallel_config.ep_size = 1
 
         moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
-                                                         dtype=torch.bool),
-                                    padded_num_tokens=num_tokens)
+        forward_context = mock_dist_env['mock_forward_context_obj']
         with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                    return_value=forward_context):
             output = moe.forward(inputs,
@@ -394,25 +343,10 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
                              [[256, 4], [128, 1], [128, 1], [128, 4]])
     def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                       mock_moe_env, others_param):
-
         global_num_experts, ep_size = others_param
         is_prefill = False
-        is_deepseek_v3_r1 = global_num_experts == 256
-
-        if ep_size == 1:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_allgather']
-        elif ep_size < 16:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_all2allv']
-        else:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_mc2']
 
-        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
-            ep_size, is_prefill, is_deepseek_v3_r1),
-                                    with_quant=False,
-                                    token_dispatcher=selected_token_dispatcher)
+        forward_context = mock_dist_env['mock_forward_context_obj']
 
         with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                    return_value=forward_context):
@@ -438,35 +372,22 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                 global_num_experts=global_num_experts,
                 is_prefill=is_prefill)
 
-        expected_shape = (16, 2)
+        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
+        mock_moe_comm_method.fused_experts.assert_called_once()
 
+        expected_shape = (16, 2)
         assert result.shape == expected_shape
 
     @pytest.mark.parametrize("others_param", [16, 1, 4])
     def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                    mock_moe_env, others_param):
-
         ep_size = others_param
         is_prefill = False
 
-        if ep_size == 1:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_allgather']
-        elif ep_size < 16:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_all2allv']
-        else:
-            selected_token_dispatcher = mock_dist_env[
-                'mock_token_dispatcher_with_mc2']
-
-        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
-            ep_size, is_prefill, True),
-                                    with_quant=False,
-                                    token_dispatcher=selected_token_dispatcher)
+        forward_context = mock_dist_env['mock_forward_context_obj']
 
         with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
             patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
-
             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
@@ -493,8 +414,10 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                 expert_map=expert_map,
                 is_prefill=is_prefill)
 
-        expected_shape = (16, 2)
+        mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
+        mock_moe_comm_method.fused_experts.assert_called_once()
 
+        expected_shape = (16, 2)
         assert result.shape == expected_shape
 
 
@@ -574,7 +497,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                      mock_get_forward_context):
 
         mock_forward_context = MagicMock()
-        mock_forward_context.fused_moe_state = FusedMoEState.MC2
+        mock_forward_context.moe_comm_method_name = "mc2commimpl"
         mock_get_forward_context.return_value = mock_forward_context
 
         mock_is_310p.return_value = False
@@ -618,8 +541,6 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                       with_quant=True)
 
         mock_get_forward_context.assert_called()
-        self.assertEqual(mock_forward_context.fused_moe_state,
-                         FusedMoEState.MC2)
 
         mock_npu_dynamic_quant.assert_called()

tests/ut/ops/test_token_dispatcher.py

Lines changed: 1 addition & 98 deletions
@@ -23,8 +23,7 @@
 
 from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
     AscendSocVersion, TokenDispatcherWithAll2AllV,
-    TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
-    _register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
+    TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
 
 
 class TestTokenDispatcherWithMC2(TestBase):
@@ -495,99 +494,3 @@ def test_token_dispatch_with_log2phy(self):
         self.assertIsNotNone(result["hidden_states"])
         self.assertIsNotNone(result["group_list"])
         self.assertEqual(result["group_list_type"], 1)
-
-
-class TestDispatcherRegistry(TestBase):
-
-    def setUp(self):
-        _Dispatchers.clear()
-
-    def tearDown(self):
-        _Dispatchers.clear()
-
-    def test_register_and_get_token_dispatcher(self):
-        mock_dispatcher = MagicMock()
-        mock_dispatcher.__class__.__name__ = "MockDispatcher"
-
-        _register_token_dispatcher(mock_dispatcher)
-
-        self.assertIn("MockDispatcher", _Dispatchers)
-        self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
-
-        retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
-        self.assertIs(retrieved_dispatcher, mock_dispatcher)
-
-        self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
-
-    @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather')
-    @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
-    def test_setup_token_dispatchers_ep_size_1_creates_allgather(
-            self, mock_register, mock_allgather_class):
-        kwargs = {"top_k": 2, "num_experts": 8}
-        mock_instance = MagicMock()
-        mock_allgather_class.return_value = mock_instance
-
-        self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
-
-        setup_token_dispatchers(ep_size=1, **kwargs)
-
-        mock_allgather_class.assert_called_once_with(**kwargs)
-        mock_register.assert_called_once_with(mock_instance)
-
-    @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
-    @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
-    def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
-            self, mock_register, mock_all2allv_class):
-        kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
-        mock_instance = MagicMock()
-        mock_all2allv_class.return_value = mock_instance
-
-        self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
-
-        setup_token_dispatchers(ep_size=2, **kwargs)
-
-        mock_all2allv_class.assert_called_once_with(**kwargs)
-        mock_register.assert_called_once_with(mock_instance)
-
-    @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
-    @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
-    @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
-    def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
-            self, mock_register, mock_mc2_class, mock_all2allv_class):
-        kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
-        mock_all2allv_instance = MagicMock()
-        mock_mc2_instance = MagicMock()
-        mock_all2allv_class.return_value = mock_all2allv_instance
-        mock_mc2_class.return_value = mock_mc2_instance
-
-        self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
-        self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
-
-        setup_token_dispatchers(ep_size=16, **kwargs)
-
-        mock_all2allv_class.assert_called_once_with(**kwargs)
-        mock_mc2_class.assert_called_once_with(**kwargs)
-        self.assertEqual(mock_register.call_count, 2)
-        mock_register.assert_any_call(mock_all2allv_instance)
-        mock_register.assert_any_call(mock_mc2_instance)
-
-    @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
-    @patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
-    @patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
-    def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
-            self, mock_register, mock_mc2_class, mock_all2allv_class):
-        kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
-        mock_existing_all2allv = MagicMock()
-        mock_existing_mc2 = MagicMock()
-        _Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
-        _Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
-
-        setup_token_dispatchers(ep_size=16, **kwargs)
-
-        mock_all2allv_class.assert_not_called()
-        mock_mc2_class.assert_not_called()
-        mock_register.assert_not_called()
-        self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
-                      mock_existing_all2allv)
-        self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
-                      mock_existing_mc2)
