Commit 627327f

Fail on MTP layer not found.
1 parent 3d60377 commit 627327f

2 files changed: +12, -5 lines changed


python/sglang/srt/models/ernie4.py

Lines changed: 6 additions & 3 deletions
@@ -52,9 +52,12 @@ def __init__(
         self.weight = nn.Parameter(
             torch.empty((config.moe_num_experts, config.hidden_size))
         )
-        self.e_score_correction_bias = nn.Parameter(
-            torch.empty((1, config.moe_num_experts))
-        )
+        if getattr(config, "moe_use_aux_free", False):
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty((1, config.moe_num_experts))
+            )
+        else:
+            self.e_score_correction_bias = None
 
     def forward(self, hidden_states):
         logits = F.linear(hidden_states, self.weight, None)
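Downstream code now has to treat the correction bias as optional. A minimal caller-side sketch, assuming a hypothetical select_experts helper (not the sglang routing code), of how a possibly-None bias can be consumed:

import torch

def select_experts(logits, e_score_correction_bias, top_k=2):
    # Hypothetical helper: apply the aux-free correction bias only when the
    # gate allocated it; with moe_use_aux_free unset the attribute is None.
    scores = torch.softmax(logits, dim=-1)
    if e_score_correction_bias is not None:
        scores = scores + e_score_correction_bias
    return torch.topk(scores, k=top_k, dim=-1)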

python/sglang/srt/models/ernie4_mtp.py

Lines changed: 6 additions & 2 deletions
@@ -52,7 +52,7 @@ def __init__(
         self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mtp_hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mtp_linear_proj = nn.Linear(
-            config.hidden_size * 2, config.hidden_size, bias=False
+            config.hidden_size * 2, config.hidden_size, bias=config.use_bias
         )
         self.mtp_block = Ernie4DecoderLayer(
             config=config,

@@ -139,6 +139,7 @@ def forward(
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        mtp_layer_found = False
         mtp_weight_patterns = [
             f"mtp_block.{self.mtp_layer_id}",
             f"mtp_emb_norm.{self.mtp_layer_id}",

@@ -150,11 +151,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             # Only name matched patterns should be loaded
             for layer_pattern in mtp_weight_patterns:
                 if layer_pattern in name:
+                    mtp_layer_found = True
                     break
             else:
                 continue
             # But strip mtp_layer_id before loading, because each MTP layer is a MTP model.
-            name = name.replace(f".{self.mtp_layer_id}", "")
+            name = name.replace(f".{self.mtp_layer_id}.", ".")
             for (
                 param_name,
                 weight_name,

@@ -176,6 +178,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param, loaded_weight)
                 else:
                     raise KeyError(f"Parameter '{name}' not found in MTP model.")
+        if not mtp_layer_found:
+            raise KeyError(f"MTP layers 'mtp_*.{self.mtp_layer_id}.*' not found in weights.")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
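The new mtp_layer_found flag turns a silent no-op into a hard failure: if none of the incoming weight names matches an MTP pattern, load_weights now raises after the loop instead of quietly loading nothing. A standalone sketch of that fail-fast check, assuming an illustrative pattern list (the hunk above only shows the first two entries):

from typing import Iterable

def check_mtp_layer_present(weight_names: Iterable[str], mtp_layer_id: int = 0) -> None:
    # Sketch of the fail-fast added in this commit: remember whether any
    # weight name matched an MTP pattern and raise after the loop if not.
    mtp_weight_patterns = [
        f"mtp_block.{mtp_layer_id}",
        f"mtp_emb_norm.{mtp_layer_id}",
    ]
    mtp_layer_found = False
    for name in weight_names:
        if any(pattern in name for pattern in mtp_weight_patterns):
            mtp_layer_found = True
    if not mtp_layer_found:
        raise KeyError(f"MTP layers 'mtp_*.{mtp_layer_id}.*' not found in weights.")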

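The tightened name rewrite matters when the layer id can also appear as the prefix of another numeric path component. A small standalone comparison, using a hypothetical checkpoint name and mtp_layer_id = 1:

mtp_layer_id = 1
name = "mtp_block.1.mlp.experts.10.down_proj.weight"  # hypothetical weight name

# Old rewrite: matches any ".1" substring, so it also corrupts the expert index.
print(name.replace(f".{mtp_layer_id}", ""))
# mtp_block.mlp.experts0.down_proj.weight

# New rewrite: only collapses the full ".1." path component.
print(name.replace(f".{mtp_layer_id}.", "."))
# mtp_block.mlp.experts.10.down_proj.weight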