from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (
+    get_ep_group,
+    get_pp_group,
+    get_tensor_model_parallel_world_size,
+)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
)
from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
@@ -118,12 +122,34 @@ def __init__(
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
+        enable_eplb: bool = False,
    ):
        super().__init__()

        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx
        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None)
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts: int = config.moe_num_experts
+        self.n_shared_experts: int = self.moe_num_shared_experts
+
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+        self.physical_expert_end = (
+            self.physical_expert_start + self.n_local_physical_experts
+        )
        self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0

        if self.tp_size > config.moe_num_experts:
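
For reference, the per-rank bookkeeping added above amounts to an even split of the physical experts (the logical routed experts plus EPLB's redundant replicas) across the expert-parallel group. Here is a standalone sketch of the same arithmetic, using assumed values (64 routed experts, 2 redundant experts, EP size 2) rather than anything taken from the Ernie 4.5 config:

    # Illustrative only; the concrete numbers are assumptions, not Ernie 4.5 defaults.
    n_routed_experts = 64        # config.moe_num_experts
    n_redundant_experts = 2      # parallel_config.num_redundant_experts
    ep_size = 2                  # expert-parallel group size

    n_logical_experts = n_routed_experts
    n_physical_experts = n_logical_experts + n_redundant_experts   # 66
    n_local_physical_experts = n_physical_experts // ep_size       # 33

    for ep_rank in range(ep_size):
        start = ep_rank * n_local_physical_experts
        end = start + n_local_physical_experts
        print(f"EP rank {ep_rank}: physical experts [{start}, {end})")
    # EP rank 0: physical experts [0, 33)
    # EP rank 1: physical experts [33, 66)
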
@@ -171,6 +197,8 @@ def __init__(
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            e_score_correction_bias=self.gate.e_score_correction_bias,
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -298,6 +326,7 @@ def __init__(
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
+        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
@@ -338,7 +367,10 @@ def __init__(
            and layer_idx <= moe_layer_end_index
        ):
            self.mlp = Ernie4_5_MoeMoE(
-                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+                config=config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
+                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = Ernie4_5_MoeMLP(
@@ -393,6 +425,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.config = config
+        parallel_config = vllm_config.parallel_config
+        enable_eplb = parallel_config.enable_eplb
+        self.num_redundant_experts = parallel_config.num_redundant_experts

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
@@ -411,6 +446,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
+                enable_eplb=enable_eplb,
            ),
            prefix=f"{prefix}.layers",
        )
@@ -465,6 +501,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.moe_num_experts,
+            num_redundant_experts=self.num_redundant_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -513,15 +550,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
+                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping

                    if weight_name not in name:
                        continue

-                    name = name.replace(weight_name, param_name)
+                    # This is an expert weight; it must not fall through to
+                    # the non-expert loading path later.
+                    is_expert_weight = True
+
+                    # Do not modify `name` in place, since the loop may still
+                    # continue here; create a new variable instead.
+                    name_mapped = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    # Skip loading extra bias for GPTQ models.
@@ -541,6 +585,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                    )
                    break
                else:
+                    if is_expert_weight:
+                        # We've confirmed this is an expert weight, but it is
+                        # not mapped to this rank, so simply skip it.
+                        continue
+
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
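
The two hunks above lean on Python's for/else: the else branch runs only when no expert mapping was loaded via break, and the new is_expert_weight flag together with name_mapped keeps an expert weight owned by another rank from falling through to the dense-weight path. A self-contained sketch of that control flow, with invented mapping and weight names (not the real vLLM mapping tuples):

    # Toy version of the for/else pattern in load_weights; all names are invented.
    def route(name: str, expert_mappings: list[tuple[str, str]], local: set[str]) -> str:
        result = None
        is_expert_weight = False
        for param_name, weight_name in expert_mappings:
            if weight_name not in name:
                continue
            is_expert_weight = True                  # expert weight, even if not local
            name_mapped = name.replace(weight_name, param_name)
            if name_mapped not in local:             # owned by another EP rank
                continue
            result = f"loaded expert weight {name_mapped}"
            break
        else:
            # Loop finished without a break: nothing was loaded locally.
            if is_expert_weight:
                result = "skipped: expert weight owned by another rank"
            else:
                result = f"loaded dense weight {name}"
        return result

    print(route(
        "model.layers.0.mlp.experts.3.up_proj.weight",
        [("experts.w13_weight", "experts.3.up_proj")],
        local={"model.layers.0.mlp.experts.w13_weight.weight"},
    ))
    # loaded expert weight model.layers.0.mlp.experts.w13_weight.weight
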
@@ -563,7 +613,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return loaded_params


-class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@@ -605,6 +655,81 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            self.model.make_empty_intermediate_tensors
        )

+        self.expert_weights = []
+
+        # Set MoE hyperparameters
+        moe_layers_indices = [
+            i
+            for i in range(config.num_hidden_layers)
+            if (
+                i >= config.moe_layer_start_index
+                and i <= config.moe_layer_end_index
+                and (i + 1) % config.moe_layer_interval == 0
+            )
+        ]
+        self.num_moe_layers = len(moe_layers_indices)
+        self.num_expert_groups = 1
+
+        self.moe_layers: list[SharedFusedMoE] = []
+        example_moe = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, Ernie4_5_MoeDecoderLayer)
+            if isinstance(layer.mlp, Ernie4_5_MoeMoE):
+                example_moe = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_moe is None:
+            logger.warning("No Ernie4_5_MoeMoE layer found in model.layers.")
+            self.num_logical_experts = 0
+            self.num_physical_experts = 0
+            self.num_local_physical_experts = 0
+            self.num_routed_experts = 0
+            self.num_shared_experts = 0
+            self.num_redundant_experts = 0
+        else:
+            self.num_logical_experts = example_moe.n_logical_experts
+            self.num_physical_experts = example_moe.n_physical_experts
+            self.num_local_physical_experts = example_moe.n_local_physical_experts
+            self.num_routed_experts = example_moe.n_routed_experts
+            self.num_shared_experts = example_moe.n_shared_experts
+            self.num_redundant_experts = example_moe.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, Ernie4_5_MoeMoE):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

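One consequence of update_physical_experts_metadata above: after EPLB rescales the expert pool, the redundant-expert count is simply recomputed as physical minus logical experts, and the new pool must still split evenly across the expert-parallel group. A small worked example with assumed numbers, not taken from the Ernie 4.5 config:

    # Assumed numbers for illustration only.
    num_logical_experts = 64
    ep_size = 4
    num_physical_experts = 68                                       # after EPLB adds 4 replicas
    num_local_physical_experts = num_physical_experts // ep_size    # 17 per rank

    num_redundant_experts = num_physical_experts - num_logical_experts
    assert num_redundant_experts == 4
    assert num_local_physical_experts * ep_size == num_physical_experts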