@@ -123,14 +123,19 @@ def __init__(
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")

-        self.experts = FusedMoE(num_experts=config.moe_num_experts,
-                                top_k=config.moe_k,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=True,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+        self.gate.e_score_correction_bias = nn.Parameter(
+            torch.empty(config.moe_num_experts))
+
+        self.experts = FusedMoE(
+            num_experts=config.moe_num_experts,
+            top_k=config.moe_k,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            e_score_correction_bias=self.gate.e_score_correction_bias)

         if self.moe_num_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
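The hunk above registers a per-expert e_score_correction_bias on the gate and passes it through to FusedMoE. Below is a minimal sketch of how such a correction bias is commonly applied during routing, assuming DeepSeek-V3-style selection where the bias only influences which experts are picked while the unbiased scores produce the combine weights; the function name and sigmoid scoring are illustrative assumptions, not the FusedMoE kernel itself.

import torch


def biased_topk_routing(router_logits: torch.Tensor,
                        e_score_correction_bias: torch.Tensor,
                        top_k: int,
                        renormalize: bool = True):
    # Unbiased routing scores, one per (token, expert) pair.
    scores = torch.sigmoid(router_logits)
    # The correction bias only influences which experts get selected ...
    _, topk_ids = torch.topk(scores + e_score_correction_bias, top_k, dim=-1)
    # ... while the combine weights come from the original, unbiased scores.
    topk_weights = scores.gather(-1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Example: route 4 tokens over 8 experts, keeping the top-2 per token.
logits = torch.randn(4, 8)
bias = torch.zeros(8)
weights, ids = biased_topk_routing(logits, bias, top_k=2)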
@@ -459,6 +464,10 @@ def load_weights(self, weights: Iterable[tuple[str,
             if "mtp" in name:
                 continue

+            if "e_score_correction_bias" in name:
+                name = name.replace("moe_statics", "gate")
+                loaded_weight = loaded_weight.squeeze(0)
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
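The load_weights change remaps the bias from its checkpoint name onto the gate submodule created above. Here is a small worked example of what that remap does, under the assumption (suggested by the squeeze(0)) that the checkpoint stores the tensor under a moe_statics module with a leading dimension of 1; the layer prefix and expert count are made up for illustration.

import torch

# Hypothetical checkpoint entry: name plus a [1, num_experts] tensor.
name = "model.layers.0.mlp.moe_statics.e_score_correction_bias"
loaded_weight = torch.zeros(1, 64)

if "e_score_correction_bias" in name:
    # Point the weight at the parameter registered on the gate ...
    name = name.replace("moe_statics", "gate")
    # ... and drop the leading dim so the shape matches
    # nn.Parameter(torch.empty(num_experts)).
    loaded_weight = loaded_weight.squeeze(0)

assert name == "model.layers.0.mlp.gate.e_score_correction_bias"
assert loaded_weight.shape == (64,)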