
Commit 032dabb

HsChen-sys (Haisheng Chen) authored and committed
[EPLB] Support ernie4.5-moe (vllm-project#22100)
Signed-off-by: Haisheng Chen <[email protected]>
Co-authored-by: Haisheng Chen <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
1 parent 8935115 commit 032dabb

File tree

1 file changed: +132 −7 lines


vllm/model_executor/models/ernie45_moe.py

Lines changed: 132 additions & 7 deletions
@@ -33,8 +33,12 @@

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (
+    get_ep_group,
+    get_pp_group,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
@@ -58,7 +62,7 @@
 )
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
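The newly imported MixtureOfExperts base class is what the rest of this diff implements for Ernie 4.5. Roughly, the model ends up exposing the following attributes and hooks; this skeleton is assembled from the hunks below for orientation and is not code from the file itself:

# Shape of the MixtureOfExperts support added on Ernie4_5_MoeForCausalLM
# (names taken from the hunks below; the class itself is illustrative only).
class _MoESupportShape:
    expert_weights: list             # per-MoE-layer expert weight views
    num_moe_layers: int
    num_expert_groups: int
    num_logical_experts: int
    num_physical_experts: int
    num_local_physical_experts: int
    num_routed_experts: int
    num_shared_experts: int
    num_redundant_experts: int

    def set_eplb_state(self, expert_load_view,
                       logical_to_physical_map,
                       logical_replica_count) -> None: ...

    def update_physical_experts_metadata(self, num_physical_experts,
                                         num_local_physical_experts) -> None: ...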
@@ -118,12 +122,34 @@ def __init__(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()

         layer_idx = extract_layer_index(prefix)
         self.layer_idx = layer_idx
         self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None)
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts: int = config.moe_num_experts
+        self.n_shared_experts: int = self.moe_num_shared_experts
+
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+        self.physical_expert_end = (
+            self.physical_expert_start + self.n_local_physical_experts
+        )
         self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0

         if self.tp_size > config.moe_num_experts:
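The block added above splits the physical experts (logical routed experts plus EPLB's redundant replicas) evenly across expert-parallel ranks. A standalone sketch of that arithmetic, using illustrative numbers rather than values from any real Ernie 4.5 configuration:

# Sketch of the EPLB expert partition above; the numbers are illustrative.
n_logical_experts = 64        # config.moe_num_experts
n_redundant_experts = 4       # parallel_config.num_redundant_experts
ep_size = 4                   # expert-parallel world size

n_physical_experts = n_logical_experts + n_redundant_experts    # 68
n_local_physical_experts = n_physical_experts // ep_size        # 17

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    print(f"rank {ep_rank} hosts physical experts [{start}, {end})")
# rank 0 hosts physical experts [0, 17)
# rank 1 hosts physical experts [17, 34)
# rank 2 hosts physical experts [34, 51)
# rank 3 hosts physical experts [51, 68)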
@@ -171,6 +197,8 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
             e_score_correction_bias=self.gate.e_score_correction_bias,
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -298,6 +326,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -338,7 +367,10 @@ def __init__(
             and layer_idx <= moe_layer_end_index
         ):
             self.mlp = Ernie4_5_MoeMoE(
-                config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+                config=config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.mlp",
+                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = Ernie4_5_MoeMLP(
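The surrounding condition decides, per decoder layer, whether to build Ernie4_5_MoeMoE or the dense Ernie4_5_MoeMLP; the same predicate reappears below as moe_layers_indices in the CausalLM constructor. A small standalone sketch with made-up config values, not the shipped ERNIE settings:

# Which layers get an MoE block, for hypothetical config values.
num_hidden_layers = 12
moe_layer_start_index = 1
moe_layer_end_index = 10
moe_layer_interval = 2

moe_layers = [
    i
    for i in range(num_hidden_layers)
    if (
        i >= moe_layer_start_index
        and i <= moe_layer_end_index
        and (i + 1) % moe_layer_interval == 0
    )
]
print(moe_layers)  # [1, 3, 5, 7, 9]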
@@ -393,6 +425,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.config = config
+        parallel_config = vllm_config.parallel_config
+        enable_eplb = parallel_config.enable_eplb
+        self.num_redundant_experts = parallel_config.num_redundant_experts

         if get_pp_group().is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
@@ -411,6 +446,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 prefix=prefix,
+                enable_eplb=enable_eplb,
             ),
             prefix=f"{prefix}.layers",
         )
@@ -465,6 +501,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.moe_num_experts,
+            num_redundant_experts=self.num_redundant_experts,
         )

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -513,15 +550,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping

                     if weight_name not in name:
                         continue

-                    name = name.replace(weight_name, param_name)
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here
+                    # Instead, create a new variable
+                    name_mapped = name.replace(weight_name, param_name)
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue

                     # Skip loading extra bias for GPTQ models.
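The comments above hinge on Python's for/else: mutating name on a partial match would leak into later mappings and into the fallback branch, so the hunk works on name_mapped and records is_expert_weight instead. A toy illustration of that control flow (not vLLM code; the mapping entries and the local_expert_ids set are invented for the example):

# Toy for/else + flag pattern: an expert weight that matches a mapping but is
# not hosted on this rank is skipped; a non-expert weight falls through.
expert_params_mapping = [
    ("experts.w13_weight", "experts.0.gate_proj", 0, "w1"),
    ("experts.w13_weight", "experts.1.gate_proj", 1, "w1"),
]
local_expert_ids = {0}  # hypothetical: this rank only hosts expert 0

def try_load(name: str) -> str:
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        is_expert_weight = True
        # Work on a copy; `name` itself stays untouched for later mappings.
        name_mapped = name.replace(weight_name, param_name)
        if expert_id not in local_expert_ids:
            continue
        return f"loaded {name_mapped} (expert {expert_id}, shard {shard_id})"
    else:
        if is_expert_weight:
            return "expert weight, but not local to this rank: skipped"
        return "not an expert weight: handled by the generic path"

print(try_load("model.layers.3.mlp.experts.1.gate_proj.weight"))
# expert weight, but not local to this rank: skipped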
@@ -541,6 +585,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if (
                         name.endswith(".bias") or name.endswith("_bias")
@@ -563,7 +613,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params


-class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -605,6 +655,81 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.model.make_empty_intermediate_tensors
         )

+        self.expert_weights = []
+
+        # Set MoE hyperparameters
+        moe_layers_indices = [
+            i
+            for i in range(config.num_hidden_layers)
+            if (
+                i >= config.moe_layer_start_index
+                and i <= config.moe_layer_end_index
+                and (i + 1) % config.moe_layer_interval == 0
+            )
+        ]
+        self.num_moe_layers = len(moe_layers_indices)
+        self.num_expert_groups = 1
+
+        self.moe_layers: list[SharedFusedMoE] = []
+        example_moe = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, Ernie4_5_MoeDecoderLayer)
+            if isinstance(layer.mlp, Ernie4_5_MoeMoE):
+                example_moe = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_moe is None:
+            logger.warning("No Ernie4_5_MoeMoE layer found in model.layers.")
+            self.num_logical_experts = 0
+            self.num_physical_experts = 0
+            self.num_local_physical_experts = 0
+            self.num_routed_experts = 0
+            self.num_shared_experts = 0
+            self.num_redundant_experts = 0
+        else:
+            self.num_logical_experts = example_moe.n_logical_experts
+            self.num_physical_experts = example_moe.n_physical_experts
+            self.num_local_physical_experts = example_moe.n_local_physical_experts
+            self.num_routed_experts = example_moe.n_routed_experts
+            self.num_shared_experts = example_moe.n_shared_experts
+            self.num_redundant_experts = example_moe.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, Ernie4_5_MoeMoE):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
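update_physical_experts_metadata above recomputes the redundancy from the fixed logical expert count whenever the EPLB rebalancer changes the physical layout. A quick worked example with hypothetical numbers, not taken from any real deployment:

# Hypothetical post-rebalance numbers.
num_logical_experts = 64           # fixed: config.moe_num_experts
num_physical_experts = 72          # new global count chosen by the rebalancer
num_local_physical_experts = 9     # 72 physical experts over 8 EP ranks

num_redundant_experts = num_physical_experts - num_logical_experts
print(num_redundant_experts)       # 8 replicas for EPLB to place on hot experts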
