@@ -192,7 +192,7 @@ def __init__(
192192            self .use_low_precision_combine  =  (os .environ .get (
193193                "TRTLLM_MOE_USE_LOW_PRECISION_COMBINE" , "0" )
194194                                              ==  "1" ) and  qm .has_nvfp4 ()
195-              
195+ 
196196            if  self .alltoall_method_type  ==  AlltoallMethodType .MNNVL :
197197                MnnvlMemory .initialize ()
198198                self .alltoall_workspace  =  MnnvlMoe .get_moe_workspaces (
@@ -296,6 +296,9 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
296296                1 ) //  self .moe_max_num_tokens 
297297
298298    def  can_use_alltoall (self , all_rank_num_tokens , all_rank_max_num_tokens ):
299+         if  self .alltoall_method_type  ==  AlltoallMethodType .MNNVL :
300+             return  True 
301+ 
299302        # Disable alltoall when chunking is used 
300303        if  self .calculate_num_chunks (all_rank_num_tokens ) >  1 :
301304            return  False 
@@ -453,12 +456,12 @@ def forward_chunk(
453456        else :
454457            tuner_num_tokens  =  None 
455458            tuner_top_k  =  None 
459+         alltoall_info  =  None 
456460        if  use_all_to_all :
457461            if  self .alltoall_method_type  ==  AlltoallMethodType .MNNVL :
458462                if  self .enable_dummy_allreduce :
459463                    self .dummy_allreduce ()
460464                token_count  =  x .shape [0 ]
461-                 alltoall_info  =  None 
462465                if  is_last_call  and  self .layer_load_balancer  is  not None  and  not  self .layer_load_balancer .is_static_routing (
463466                ):
464467                    loadbalancer_local_statistic_info  =  self .layer_load_balancer .get_local_statistic_tensor (
@@ -469,7 +472,7 @@ def forward_chunk(
469472                    self .alltoall_prepare (all_rank_max_num_tokens ,
470473                                          token_selected_slots ,
471474                                          loadbalancer_local_statistic_info )
472-                                            
475+ 
473476                if  gathered_loadbalancer_local_statistic_info  is  not None :
474477                    gathered_loadbalancer_local_statistic_info  =  gathered_loadbalancer_local_statistic_info .view (
475478                        (self .mapping .moe_ep_size , self .num_experts ))
@@ -577,10 +580,13 @@ def forward_chunk(
577580        if  self .alltoall_method_type  ==  AlltoallMethodType .MNNVL :
578581            top_k  =  self .routing_method .experts_per_token 
579582            x , x_sf , token_selected_slots , token_final_scales  =  self .alltoall_dispatch (
580-                 x , x_sf , token_selected_slots , token_final_scales , all_rank_max_num_tokens , top_k , alltoall_info )
583+                 x , x_sf , token_selected_slots , token_final_scales ,
584+                 all_rank_max_num_tokens , top_k , alltoall_info )
581585
582586        if  use_postquant_alltoall :
583-             if  self .alltoall_method_type  ==  AlltoallMethodType .DeepEP :
587+             if  self .alltoall_method_type  ==  AlltoallMethodType .MNNVL :
588+                 pass 
589+             elif  self .alltoall_method_type  ==  AlltoallMethodType .DeepEP :
584590                if  x_sf  is  not None :
585591                    # Adapter between `x_sf` and DeepEP 
586592                    # TODO: remove the adapter by adding dtype support to DeepEP 
@@ -858,34 +864,32 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
858864        self .repeat_idx  =  0  if  self .repeat_idx  ==  self .repeat_count  -  1  else  self .repeat_idx  +  1 
859865        return  outputs 
860866
861-     def  alltoall_prepare (
862-             self , all_rank_max_num_tokens : int ,
863-             token_selected_slots : torch .Tensor ,
864-             local_statistic_tensor : Optional [torch .Tensor ]):
867+     def  alltoall_prepare (self , all_rank_max_num_tokens : int ,
868+                          token_selected_slots : torch .Tensor ,
869+                          local_statistic_tensor : Optional [torch .Tensor ]):
865870        top_k  =  self .routing_method .experts_per_token 
866871
867872        alltoall_info , gathered_local_statistic_tensor  =  MnnvlMoe .mnnvl_moe_alltoallv_prepare_without_allgather (
868-             token_selected_slots ,
869-             local_statistic_tensor , self .alltoall_prepare_workspace ,
870-             all_rank_max_num_tokens , self .ep_rank , self .ep_size ,
871-             self .num_experts , self .num_slots , top_k )
873+             token_selected_slots , local_statistic_tensor ,
874+             self .alltoall_prepare_workspace , all_rank_max_num_tokens ,
875+             self .ep_rank , self .ep_size , self .num_experts , self .num_slots , top_k )
872876
873877        return  token_selected_slots , gathered_local_statistic_tensor , alltoall_info 
874878
875879    def  alltoall_dispatch (self , x : torch .Tensor , x_sf : Optional [torch .Tensor ],
876-                              token_selected_slots : torch .Tensor ,  
877-                              token_final_scales : Optional [torch .Tensor ],
878-                              all_rank_max_num_tokens : int ,
879-                              top_k :  int , 
880-                              alltoall_info :  MoEAlltoallInfo ): 
881-         
882-         x ,  x_sf ,  token_selected_slots ,  token_final_scales   =   MnnvlMoe . mnnvl_moe_alltoallv ( [x , x_sf , token_selected_slots , token_final_scales ], alltoall_info ,
883-                                           self .alltoall_workspace , self .ep_rank ,
884-                                           self . ep_size ) 
885-         
886-         torch . ops . trtllm . memset_expert_ids ( 
887-                     token_selected_slots ,  alltoall_info . recv_rank_count_cumsum ,
888-                     all_rank_max_num_tokens ,  top_k ,  self .num_slots , self .ep_size )
880+                           token_selected_slots : torch .Tensor ,
881+                           token_final_scales : Optional [torch .Tensor ],
882+                           all_rank_max_num_tokens :  int ,  top_k : int ,
883+                           alltoall_info :  MoEAlltoallInfo ): 
884+ 
885+         x ,  x_sf ,  token_selected_slots ,  token_final_scales   =   MnnvlMoe . mnnvl_moe_alltoallv ( 
886+              [x , x_sf , token_selected_slots , token_final_scales ], alltoall_info ,
887+             self .alltoall_workspace , self .ep_rank ,  self . ep_size ) 
888+ 
889+         torch . ops . trtllm . memset_expert_ids ( token_selected_slots , 
890+                                             alltoall_info . recv_rank_count_cumsum , 
891+                                             all_rank_max_num_tokens ,  top_k ,
892+                                             self .num_slots , self .ep_size )
889893
890894        return  x , x_sf , token_selected_slots , token_final_scales 
891895
0 commit comments