@@ -52,7 +52,7 @@ def __init__(
         self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mtp_hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mtp_linear_proj = nn.Linear(
-            config.hidden_size * 2, config.hidden_size, bias=False
+            config.hidden_size * 2, config.hidden_size, bias=config.use_bias
         )
         self.mtp_block = Ernie4DecoderLayer(
             config=config,
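For context on the layer being changed: `mtp_linear_proj` takes `hidden_size * 2` inputs because the MTP head fuses the normalized token embedding with the normalized hidden state before running the extra decoder layer. A minimal sketch of that data flow, with assumed shapes and a DeepSeek-V3-style concatenation (`hidden_size` and `use_bias` stand in for the config values; not the exact Ernie code):

```python
import torch
import torch.nn as nn

hidden_size, use_bias = 8, False  # stand-ins for config.hidden_size / config.use_bias
proj = nn.Linear(hidden_size * 2, hidden_size, bias=use_bias)

emb = torch.randn(2, 4, hidden_size)   # normalized token embeddings (mtp_emb_norm output)
prev = torch.randn(2, 4, hidden_size)  # normalized prior hidden states (mtp_hidden_norm output)
fused = proj(torch.cat([emb, prev], dim=-1))  # -> (2, 4, hidden_size)
```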
@@ -139,6 +139,7 @@ def forward(
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        mtp_layer_found = False
         mtp_weight_patterns = [
            f"mtp_block.{self.mtp_layer_id}",
            f"mtp_emb_norm.{self.mtp_layer_id}",
@@ -150,11 +151,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             # Only name matched patterns should be loaded
             for layer_pattern in mtp_weight_patterns:
                 if layer_pattern in name:
+                    mtp_layer_found = True
                     break
             else:
                 continue
             # But strip mtp_layer_id before loading, because each MTP layer is a MTP model.
-            name = name.replace(f".{self.mtp_layer_id}", "")
+            name = name.replace(f".{self.mtp_layer_id}.", ".")
             for (
                 param_name,
                 weight_name,
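The trailing dot in the new `replace` call is the substance of this hunk: without it, the bare layer id can also match inside other numeric indices in the weight name. A quick illustration (the weight name below is hypothetical, assuming `self.mtp_layer_id == 1` and an MoE-style expert index):

```python
name = "mtp_block.1.mlp.experts.12.down_proj.weight"  # hypothetical checkpoint name
mtp_layer_id = 1

# Before: ".1" also matches inside the expert index ".12", corrupting it.
print(name.replace(f".{mtp_layer_id}", ""))
# mtp_block.mlp.experts2.down_proj.weight

# After: requiring the trailing dot strips only the MTP layer id.
print(name.replace(f".{mtp_layer_id}.", "."))
# mtp_block.mlp.experts.12.down_proj.weight
```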
@@ -176,6 +178,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight)
             else:
                 raise KeyError(f"Parameter '{name}' not found in MTP model.")
+        if not mtp_layer_found:
+            raise KeyError(f"MTP layers 'mtp_*.{self.mtp_layer_id}.*' not found in weights.")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
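Taken together, the `load_weights` changes implement a filter-and-guard pattern: the `for ... else` skips weights that match no MTP pattern, and the new flag makes a checkpoint with no MTP weights fail loudly instead of loading nothing in silence. A self-contained sketch of the pattern (hypothetical weight names; this deliberately contains no MTP weights, so it raises):

```python
mtp_layer_id = 0
patterns = [f"mtp_block.{mtp_layer_id}", f"mtp_emb_norm.{mtp_layer_id}"]
weights = {"model.layers.0.mlp.weight": None, "lm_head.weight": None}  # no MTP weights

mtp_layer_found = False
for name in weights:
    for pattern in patterns:
        if pattern in name:
            mtp_layer_found = True
            break
    else:
        # The for-else branch runs only when no pattern matched: skip this weight.
        continue
    pass  # a matched weight would be loaded here
if not mtp_layer_found:
    raise KeyError(f"MTP layers 'mtp_*.{mtp_layer_id}.*' not found in weights.")
```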