Skip to content

Commit 0dd9668

Browse files
committed
Update test_token_dispatcher.py
1 parent c3cb5d7 commit 0dd9668

File tree

1 file changed

+0
-96
lines changed

1 file changed

+0
-96
lines changed

tests/ut/ops/test_token_dispatcher.py

Lines changed: 0 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -495,99 +495,3 @@ def test_token_dispatch_with_log2phy(self):
495495
self.assertIsNotNone(result["hidden_states"])
496496
self.assertIsNotNone(result["group_list"])
497497
self.assertEqual(result["group_list_type"], 1)
498-
499-
500-
class TestDispatcherRegistry(TestBase):
501-
502-
def setUp(self):
503-
_Dispatchers.clear()
504-
505-
def tearDown(self):
506-
_Dispatchers.clear()
507-
508-
def test_register_and_get_token_dispatcher(self):
509-
mock_dispatcher = MagicMock()
510-
mock_dispatcher.__class__.__name__ = "MockDispatcher"
511-
512-
_register_token_dispatcher(mock_dispatcher)
513-
514-
self.assertIn("MockDispatcher", _Dispatchers)
515-
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
516-
517-
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
518-
self.assertIs(retrieved_dispatcher, mock_dispatcher)
519-
520-
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
521-
522-
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather')
523-
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
524-
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
525-
self, mock_register, mock_allgather_class):
526-
kwargs = {"top_k": 2, "num_experts": 8}
527-
mock_instance = MagicMock()
528-
mock_allgather_class.return_value = mock_instance
529-
530-
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
531-
532-
setup_token_dispatchers(ep_size=1, **kwargs)
533-
534-
mock_allgather_class.assert_called_once_with(**kwargs)
535-
mock_register.assert_called_once_with(mock_instance)
536-
537-
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
538-
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
539-
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
540-
self, mock_register, mock_all2allv_class):
541-
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
542-
mock_instance = MagicMock()
543-
mock_all2allv_class.return_value = mock_instance
544-
545-
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
546-
547-
setup_token_dispatchers(ep_size=2, **kwargs)
548-
549-
mock_all2allv_class.assert_called_once_with(**kwargs)
550-
mock_register.assert_called_once_with(mock_instance)
551-
552-
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
553-
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
554-
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
555-
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
556-
self, mock_register, mock_mc2_class, mock_all2allv_class):
557-
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
558-
mock_all2allv_instance = MagicMock()
559-
mock_mc2_instance = MagicMock()
560-
mock_all2allv_class.return_value = mock_all2allv_instance
561-
mock_mc2_class.return_value = mock_mc2_instance
562-
563-
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
564-
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
565-
566-
setup_token_dispatchers(ep_size=16, **kwargs)
567-
568-
mock_all2allv_class.assert_called_once_with(**kwargs)
569-
mock_mc2_class.assert_called_once_with(**kwargs)
570-
self.assertEqual(mock_register.call_count, 2)
571-
mock_register.assert_any_call(mock_all2allv_instance)
572-
mock_register.assert_any_call(mock_mc2_instance)
573-
574-
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV')
575-
@patch('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2')
576-
@patch('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher')
577-
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
578-
self, mock_register, mock_mc2_class, mock_all2allv_class):
579-
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
580-
mock_existing_all2allv = MagicMock()
581-
mock_existing_mc2 = MagicMock()
582-
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
583-
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
584-
585-
setup_token_dispatchers(ep_size=16, **kwargs)
586-
587-
mock_all2allv_class.assert_not_called()
588-
mock_mc2_class.assert_not_called()
589-
mock_register.assert_not_called()
590-
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
591-
mock_existing_all2allv)
592-
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
593-
mock_existing_mc2)

0 commit comments

Comments
 (0)