@@ -33,8 +33,12 @@
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (
+    get_ep_group,
+    get_pp_group,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
@@ -58,7 +62,7 @@
 )
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -118,12 +122,34 @@ def __init__(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
 
         layer_idx = extract_layer_index(prefix)
         self.layer_idx = layer_idx
         self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None)
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts: int = config.moe_num_experts
+        self.n_shared_experts: int = self.moe_num_shared_experts
+
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+        self.physical_expert_end = (
+            self.physical_expert_start + self.n_local_physical_experts
+        )
         self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0
 
         if self.tp_size > config.moe_num_experts:
@@ -171,6 +197,8 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
             e_score_correction_bias=self.gate.e_score_correction_bias,
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
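
The EPLB bookkeeping added to `Ernie4_5_MoeMoE.__init__` above is plain integer arithmetic: redundant (replicated) experts are appended to the logical experts, and the resulting physical experts are split evenly across expert-parallel ranks. A minimal standalone sketch of the same arithmetic, using hypothetical values and no vLLM APIs:

```python
# Illustration only: mirrors the arithmetic in Ernie4_5_MoeMoE.__init__ above.
n_logical_experts = 64      # config.moe_num_experts (hypothetical value)
n_redundant_experts = 8     # parallel_config.num_redundant_experts (hypothetical)
ep_size, ep_rank = 4, 1     # expert-parallel world size and this rank

n_physical_experts = n_logical_experts + n_redundant_experts   # 72
n_local_physical_experts = n_physical_experts // ep_size       # 18 per rank
physical_expert_start = ep_rank * n_local_physical_experts     # 18
physical_expert_end = physical_expert_start + n_local_physical_experts  # 36

print(physical_expert_start, physical_expert_end)  # rank 1 owns experts [18, 36)
```
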
@@ -298,6 +326,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -338,7 +367,10 @@ def __init__(
             and layer_idx <= moe_layer_end_index
         ):
             self.mlp = Ernie4_5_MoeMoE(
-                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+                config=config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
+                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = Ernie4_5_MoeMLP(
@@ -393,6 +425,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.config = config
+        parallel_config = vllm_config.parallel_config
+        enable_eplb = parallel_config.enable_eplb
+        self.num_redundant_experts = parallel_config.num_redundant_experts
 
         if get_pp_group().is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
@@ -411,6 +446,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 prefix=prefix,
+                enable_eplb=enable_eplb,
             ),
             prefix=f"{prefix}.layers",
         )
@@ -465,6 +501,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.moe_num_experts,
+            num_redundant_experts=self.num_redundant_experts,
         )
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -513,15 +550,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
 
                     if weight_name not in name:
                         continue
 
-                    name = name.replace(weight_name, param_name)
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here
+                    # Instead, create a new variable
+                    name_mapped = name.replace(weight_name, param_name)
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
 
                     # Skip loading extra bias for GPTQ models.
@@ -541,6 +585,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if (
                         name.endswith(".bias") or name.endswith("_bias")
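
The reason the hunks above stop mutating `name` is the `for ... else` idiom: the loop over `expert_params_mapping` may `continue` to later mappings, which must still see the original `name`, and the `else` branch runs only when no `break` happened, so the `is_expert_weight` flag plus the separate `name_mapped` variable keep the fallback path intact. A self-contained toy of that control flow (the mappings and names are made up, only the flow matters):

```python
# Toy illustration of the for/else pattern used in load_weights above.
expert_params_mapping = [("w13_weight", "gate_proj"), ("w13_weight", "up_proj")]
name = "layers.0.mlp.experts.3.up_proj.weight"

is_expert_weight = False
for param_name, weight_name in expert_params_mapping:
    if weight_name not in name:
        continue  # later mappings still see the *original* name
    is_expert_weight = True
    name_mapped = name.replace(weight_name, param_name)
    print("would load", name_mapped)
    break
else:
    # Runs only when the loop finished without `break`.
    if is_expert_weight:
        pass  # expert weight not mapped on this rank: skip it
    else:
        print("fall back to the generic loader for", name)
```
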
@@ -563,7 +613,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params
 
 
-class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -605,6 +655,81 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.model.make_empty_intermediate_tensors
         )
 
+        self.expert_weights = []
+
+        # Set MoE hyperparameters
+        moe_layers_indices = [
+            i
+            for i in range(config.num_hidden_layers)
+            if (
+                i >= config.moe_layer_start_index
+                and i <= config.moe_layer_end_index
+                and (i + 1) % config.moe_layer_interval == 0
+            )
+        ]
+        self.num_moe_layers = len(moe_layers_indices)
+        self.num_expert_groups = 1
+
+        self.moe_layers: list[SharedFusedMoE] = []
+        example_moe = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, Ernie4_5_MoeDecoderLayer)
+            if isinstance(layer.mlp, Ernie4_5_MoeMoE):
+                example_moe = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_moe is None:
+            logger.warning("No Ernie4_5_MoeMoE layer found in model.layers.")
+            self.num_logical_experts = 0
+            self.num_physical_experts = 0
+            self.num_local_physical_experts = 0
+            self.num_routed_experts = 0
+            self.num_shared_experts = 0
+            self.num_redundant_experts = 0
+        else:
+            self.num_logical_experts = example_moe.n_logical_experts
+            self.num_physical_experts = example_moe.n_physical_experts
+            self.num_local_physical_experts = example_moe.n_local_physical_experts
+            self.num_routed_experts = example_moe.n_routed_experts
+            self.num_shared_experts = example_moe.n_shared_experts
+            self.num_redundant_experts = example_moe.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, Ernie4_5_MoeMoE):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
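
As a quick sanity check on the `moe_layers_indices` comprehension added in the last hunk, the same filter can be run standalone. The attribute names mirror the Ernie 4.5 config fields referenced in the diff; the numeric values below are hypothetical:

```python
# Hypothetical Ernie 4.5 style config values, for illustration only.
num_hidden_layers = 12
moe_layer_start_index = 1
moe_layer_end_index = 11
moe_layer_interval = 1

# Same comprehension as in Ernie4_5_MoeForCausalLM.__init__ above.
moe_layers_indices = [
    i
    for i in range(num_hidden_layers)
    if (
        i >= moe_layer_start_index
        and i <= moe_layer_end_index
        and (i + 1) % moe_layer_interval == 0
    )
]
print(moe_layers_indices)  # [1, 2, ..., 11]: layer 0 stays dense, 11 MoE layers
```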