@@ -495,99 +495,3 @@ def test_token_dispatch_with_log2phy(self):
495495 self .assertIsNotNone (result ["hidden_states" ])
496496 self .assertIsNotNone (result ["group_list" ])
497497 self .assertEqual (result ["group_list_type" ], 1 )
498-
499-
500- class TestDispatcherRegistry (TestBase ):
501-
502- def setUp (self ):
503- _Dispatchers .clear ()
504-
505- def tearDown (self ):
506- _Dispatchers .clear ()
507-
508- def test_register_and_get_token_dispatcher (self ):
509- mock_dispatcher = MagicMock ()
510- mock_dispatcher .__class__ .__name__ = "MockDispatcher"
511-
512- _register_token_dispatcher (mock_dispatcher )
513-
514- self .assertIn ("MockDispatcher" , _Dispatchers )
515- self .assertIs (_Dispatchers ["MockDispatcher" ], mock_dispatcher )
516-
517- retrieved_dispatcher = get_token_dispatcher ("MockDispatcher" )
518- self .assertIs (retrieved_dispatcher , mock_dispatcher )
519-
520- self .assertIsNone (get_token_dispatcher ("NonExistentDispatcher" ))
521-
522- @patch ('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAllGather' )
523- @patch ('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher' )
524- def test_setup_token_dispatchers_ep_size_1_creates_allgather (
525- self , mock_register , mock_allgather_class ):
526- kwargs = {"top_k" : 2 , "num_experts" : 8 }
527- mock_instance = MagicMock ()
528- mock_allgather_class .return_value = mock_instance
529-
530- self .assertNotIn ("TokenDispatcherWithAllGather" , _Dispatchers )
531-
532- setup_token_dispatchers (ep_size = 1 , ** kwargs )
533-
534- mock_allgather_class .assert_called_once_with (** kwargs )
535- mock_register .assert_called_once_with (mock_instance )
536-
537- @patch ('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV' )
538- @patch ('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher' )
539- def test_setup_token_dispatchers_ep_size_2_creates_all2allv (
540- self , mock_register , mock_all2allv_class ):
541- kwargs = {"top_k" : 2 , "num_experts" : 16 , "num_local_experts" : 2 }
542- mock_instance = MagicMock ()
543- mock_all2allv_class .return_value = mock_instance
544-
545- self .assertNotIn ("TokenDispatcherWithAll2AllV" , _Dispatchers )
546-
547- setup_token_dispatchers (ep_size = 2 , ** kwargs )
548-
549- mock_all2allv_class .assert_called_once_with (** kwargs )
550- mock_register .assert_called_once_with (mock_instance )
551-
552- @patch ('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV' )
553- @patch ('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2' )
554- @patch ('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher' )
555- def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2 (
556- self , mock_register , mock_mc2_class , mock_all2allv_class ):
557- kwargs = {"top_k" : 2 , "num_experts" : 32 , "num_local_experts" : 2 }
558- mock_all2allv_instance = MagicMock ()
559- mock_mc2_instance = MagicMock ()
560- mock_all2allv_class .return_value = mock_all2allv_instance
561- mock_mc2_class .return_value = mock_mc2_instance
562-
563- self .assertNotIn ("TokenDispatcherWithAll2AllV" , _Dispatchers )
564- self .assertNotIn ("TokenDispatcherWithMC2" , _Dispatchers )
565-
566- setup_token_dispatchers (ep_size = 16 , ** kwargs )
567-
568- mock_all2allv_class .assert_called_once_with (** kwargs )
569- mock_mc2_class .assert_called_once_with (** kwargs )
570- self .assertEqual (mock_register .call_count , 2 )
571- mock_register .assert_any_call (mock_all2allv_instance )
572- mock_register .assert_any_call (mock_mc2_instance )
573-
574- @patch ('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithAll2AllV' )
575- @patch ('vllm_ascend.ops.moe.token_dispatcher.TokenDispatcherWithMC2' )
576- @patch ('vllm_ascend.ops.moe.token_dispatcher._register_token_dispatcher' )
577- def test_setup_token_dispatchers_ep_size_16_skips_if_exist (
578- self , mock_register , mock_mc2_class , mock_all2allv_class ):
579- kwargs = {"top_k" : 2 , "num_experts" : 32 , "num_local_experts" : 2 }
580- mock_existing_all2allv = MagicMock ()
581- mock_existing_mc2 = MagicMock ()
582- _Dispatchers ["TokenDispatcherWithAll2AllV" ] = mock_existing_all2allv
583- _Dispatchers ["TokenDispatcherWithMC2" ] = mock_existing_mc2
584-
585- setup_token_dispatchers (ep_size = 16 , ** kwargs )
586-
587- mock_all2allv_class .assert_not_called ()
588- mock_mc2_class .assert_not_called ()
589- mock_register .assert_not_called ()
590- self .assertIs (_Dispatchers ["TokenDispatcherWithAll2AllV" ],
591- mock_existing_all2allv )
592- self .assertIs (_Dispatchers ["TokenDispatcherWithMC2" ],
593- mock_existing_mc2 )
0 commit comments