
Commit 1784929

xyxinyang authored and epwalsh committed
[Model] Fix Ernie4.5MoE e_score_correction_bias parameter (vllm-project#21586)
Signed-off-by: zhouchong <[email protected]>
Co-authored-by: zhouchong <[email protected]>
1 parent beafd5e commit 1784929

1 file changed: +17, -8 lines changed

vllm/model_executor/models/ernie45_moe.py

Lines changed: 17 additions & 8 deletions
@@ -123,14 +123,19 @@ def __init__(
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")
 
-        self.experts = FusedMoE(num_experts=config.moe_num_experts,
-                                top_k=config.moe_k,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=True,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+        self.gate.e_score_correction_bias = nn.Parameter(
+            torch.empty(config.moe_num_experts))
+
+        self.experts = FusedMoE(
+            num_experts=config.moe_num_experts,
+            top_k=config.moe_k,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            e_score_correction_bias=self.gate.e_score_correction_bias)
 
         if self.moe_num_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -459,6 +464,10 @@ def load_weights(self, weights: Iterable[tuple[str,
             if "mtp" in name:
                 continue
 
+            if "e_score_correction_bias" in name:
+                name = name.replace("moe_statics", "gate")
+                loaded_weight = loaded_weight.squeeze(0)
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
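
The load_weights hunk is the subtle part of the fix: the checkpoint stores the correction bias under a moe_statics prefix (and, judging by the squeeze(0), with a leading singleton dimension), while the runtime parameter now lives on the gate module. Below is a minimal, standalone sketch of that remapping step; the checkpoint key and expert count are illustrative assumptions, not values taken from the commit.

# Standalone sketch of the rename/squeeze step added in load_weights above.
# The checkpoint key and the expert count below are illustrative assumptions.
import torch


def remap_correction_bias(name: str, loaded_weight: torch.Tensor):
    """Route the MoE statics bias onto the gate module and drop the
    leading singleton dimension so it matches the nn.Parameter shape."""
    if "e_score_correction_bias" in name:
        name = name.replace("moe_statics", "gate")
        loaded_weight = loaded_weight.squeeze(0)
    return name, loaded_weight


# Hypothetical checkpoint entry: bias stored as [1, num_experts].
name = "model.layers.0.mlp.moe_statics.e_score_correction_bias"
weight = torch.zeros(1, 64)  # 64 experts, chosen for illustration

new_name, new_weight = remap_correction_bias(name, weight)
print(new_name)           # ...mlp.gate.e_score_correction_bias
print(new_weight.shape)   # torch.Size([64])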
