Commit d15dbe9

Fix load_hf_ckpt in paddleformers (#2744)
1 parent d17af14 commit d15dbe9

1 file changed

examples/experiments/deepseek_v3_pretrain/load_hf_ckpt.py

Lines changed: 39 additions & 106 deletions
@@ -28,15 +28,14 @@
 except:
     safe_open = None

-_LAYER_RE = re.compile(r"^_layers\.(\d+)\.(\d+)(?:\.(.*))?$")
+_LAYER_RE = re.compile(r"^deepseek_v2.layers\.(\d+)\.(.*)$")
 _EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$")
 _EXPERT_W2_RE = re.compile(r"^mlp\.experts\.(\d+)\.w2(?:\.weight)?$")
 _SHARE_EXPERT_W1_RE = re.compile(r"^mlp\.shared_experts\.w1(?:\.weight)?$")
 _SHARE_EXPERT_W2_RE = re.compile(r"^mlp\.shared_experts\.w2(?:\.weight)?$")

 _EXPERT_W1_RE_v2 = re.compile(r"^mlp\.experts\.(\d+)\.gate_up_fused_proj(?:\.weight)?$")
 _SHARE_EXPERT_W1_RE_v2 = re.compile(r"^mlp\.shared_experts\.gate_up_fused_proj(?:\.weight)?$")
-_LAYER_RE_v2 = re.compile(r"_layers.deepseek_v2.layers\.(\d+)\.(.*)$")

 custom_name_map = {
     "self_attn.input_layernorm.weight": "input_layernorm.weight",
@@ -47,10 +46,16 @@
     "self_attn.memory_recompute_att.q_ln_weight": "self_attn.q_a_layernorm.weight",
     "self_attn.fused_rms_norm_linear.q_down_weight": "self_attn.q_a_proj.weight",
     "self_attn.memory_recompute_att.q_up_weight": "self_attn.q_b_proj.weight",
+    "self_attn.input_layernorm.weight": "input_layernorm.weight",
+    "mlp.gate.norm_weight": "post_attention_layernorm.weight",
+    "mlp.router.weight": "mlp.gate.weight",
+    "mlp.router.e_score_correction_bias": "mlp.gate.e_score_correction_bias",
+    "mlp.router.norm_weight": "post_attention_layernorm.weight",
+    "mlp.shared_experts.norm_weight": "post_attention_layernorm.weight",
 }


-def paddle_name_to_hf_names_ds_v2(paddle_name: str) -> List[str]:
+def paddle_name_to_hf_names(paddle_name: str) -> List[str]:
     """
     Convert Paddle model parameter names to Hugging Face format name lists

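As a quick illustration of the new router/norm entries (a hypothetical lookup; the full renaming happens in paddle_name_to_hf_names below), the three Paddle-side norm aliases all resolve to the same Hugging Face tensor name:

    # All three aliases map onto post_attention_layernorm for a given layer.
    aliases = ["mlp.gate.norm_weight", "mlp.router.norm_weight", "mlp.shared_experts.norm_weight"]
    print({custom_name_map[a] for a in aliases})  # {'post_attention_layernorm.weight'}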
@@ -60,97 +65,24 @@ def paddle_name_to_hf_names_ds_v2(paddle_name: str) -> List[str]:
     Returns:
         List of parameter names in Hugging Face format (may be split into multiple parameters)
     """
-    if paddle_name == "_layers.deepseek_v2.embed_tokens.weight":
+
+    if paddle_name == "deepseek_v2.embed_tokens.weight":
         return ["model.embed_tokens.weight"]

-    if paddle_name == "_layers.deepseek_v2.norm.weight":
+    if paddle_name == "deepseek_v2.norm.weight":
         return ["model.norm.weight"]

-    if paddle_name == "_layers.lm_head.weight":
+    if paddle_name == "lm_head.weight":
         return ["lm_head.weight"]

-    m = _LAYER_RE_v2.match(paddle_name)
-    if not m:
-        return []
-
-    rest = m.group(2) or ""
-    layer_id = m.group(1)
-    if rest in custom_name_map:
-        rest = custom_name_map[rest]
-    out_name = "model.layers." + layer_id + "." + rest
-
-    if rest == "mlp.gate_up_fused_proj.weight" or rest == "mlp.w1":
-        return [
-            "model.layers." + layer_id + ".mlp.gate_proj.weight",
-            "model.layers." + layer_id + ".mlp.up_proj.weight",
-        ]
-
-    if rest == "mlp.w2":
-        return ["model.layers." + layer_id + ".mlp.down_proj.weight"]
-
-    if rest == "mlp.shared_experts.gate_up_fused_proj.weight":
-        return [
-            "model.layers." + layer_id + ".mlp.shared_experts.gate_proj.weight",
-            "model.layers." + layer_id + ".mlp.shared_experts.up_proj.weight",
-        ]
-
-    if m := _EXPERT_W1_RE_v2.match(rest):
-        expert_id = m.group(1)
-        return [
-            "model.layers." + layer_id + ".mlp.experts." + expert_id + ".gate_proj.weight",
-            "model.layers." + layer_id + ".mlp.experts." + expert_id + ".up_proj.weight",
-        ]
-
-    if m := _EXPERT_W1_RE.match(rest):
-        expert_id = m.group(1)
-        return [
-            "model.layers." + layer_id + ".mlp.experts." + expert_id + ".gate_proj.weight",
-            "model.layers." + layer_id + ".mlp.experts." + expert_id + ".up_proj.weight",
-        ]
-
-    if m := _EXPERT_W2_RE.match(rest):
-        expert_id = m.group(1)
-        return ["model.layers." + layer_id + ".mlp.experts." + expert_id + ".down_proj.weight"]
-
-    if m := _SHARE_EXPERT_W1_RE.match(rest):
-        return [
-            "model.layers." + layer_id + ".mlp.shared_experts.gate_proj.weight",
-            "model.layers." + layer_id + ".mlp.shared_experts.up_proj.weight",
-        ]
-
-    if m := _SHARE_EXPERT_W2_RE.match(rest):
-        return ["model.layers." + layer_id + ".mlp.shared_experts.down_proj.weight"]
-
-    return [out_name]
-
-
-def paddle_name_to_hf_names(paddle_name: str) -> List[str]:
-    """
-    Convert Paddle model parameter names to Hugging Face format name lists
-
-    Args:
-        paddle_name: Parameter name in Paddle format
-
-    Returns:
-        List of parameter names in Hugging Face format (may be split into multiple parameters)
-    """
-    if paddle_name == "_layers.local_shared_layers.DeepseekV2_shared_weight.embed_tokens.weight":
-        return ["model.embed_tokens.weight"]
-
-    if paddle_name == "_layers.deepseek_v2.embed_tokens.weight":
-        return ["model.embed_tokens.weight"]
-
     m = _LAYER_RE.match(paddle_name)

     if not m:
         return []
     else:
-        rest = m.group(3) or ""
+        rest = m.group(2) or ""

-        segment_id = int(m.group(1))
-        id_in_segment = int(m.group(2))
-
-        hf_prefix = _get_hf_prefix(segment_id, id_in_segment)
+        hf_prefix = "model" + ".layers." + m.group(1)

         if rest in custom_name_map:
             return [f"{hf_prefix}.{custom_name_map[rest]}"]
@@ -197,24 +129,7 @@ def paddle_name_to_hf_names(paddle_name: str) -> List[str]:
         if m := _SHARE_EXPERT_W2_RE.match(rest):
             return [hf_prefix + ".mlp.shared_experts.down_proj.weight"]

-        return [f"{hf_prefix}.{rest}"] if rest else [hf_prefix]
-
-
-def _get_hf_prefix(segment_id: int, id_in_segment: int) -> str:
-    """Generate hierarchical prefix in Hugging Face format"""
-    # Special layer mappings
-    # special_cases = {(0, 0): "model", (60, 2): "model.layers.61", (60, 3): "model"}
-    # special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (28, 3): "model"}
-    # special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (4, 1): "model"}
-    # special_cases = {(0, 0): "model", (28, 2): "model", (28,3): "lm_head"}
-    special_cases = {(0, 0): "model", (60, 2): "model.layers.61", (60, 3): "model", (60, 4): "lm_head"}
-
-    if (segment_id, id_in_segment) in special_cases:
-        return special_cases[(segment_id, id_in_segment)]
-
-    # General layer calculation
-    layer_idx = segment_id + id_in_segment - 1
-    return f"model.layers.{layer_idx}"
+        return [paddle_name.replace("deepseek_v2", "model")]


 def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]:
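Taken together, a hedged sketch of the mapping the consolidated paddle_name_to_hf_names now produces (the parameter names are illustrative, and the middle of the function is not shown in this diff, so only the visible branches are exercised):

    print(paddle_name_to_hf_names("deepseek_v2.embed_tokens.weight"))
    # ['model.embed_tokens.weight']
    print(paddle_name_to_hf_names("deepseek_v2.layers.0.mlp.router.weight"))
    # ['model.layers.0.mlp.gate.weight']                    (via custom_name_map)
    print(paddle_name_to_hf_names("deepseek_v2.layers.0.post_attention_layernorm.weight"))
    # ['model.layers.0.post_attention_layernorm.weight']    (expected: fallback replace)
    print(paddle_name_to_hf_names("lm_head.weight"))
    # ['lm_head.weight']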
@@ -255,6 +170,25 @@ def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]:
     return None


+def _is_need_transpose(key):
+    transpose_weight_keys = [
+        "fused_rms_norm_linear.kv_down_weight",
+        "memory_recompute_att.kv_up_weight",
+        "o_proj.weight",
+        "fused_rms_norm_linear.q_down_weight",
+        "memory_recompute_att.q_up_weight",
+        "w1",
+        "w2",
+        "gate.weight",
+        "eh_proj.weight",
+        "lm_head.weight",
+    ]
+    for trans_key in transpose_weight_keys:
+        if key.endswith(trans_key):
+            return True
+    return False
+
+
 def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
     if isinstance(tensor, list):
         t = paddle.cat(
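For context, how the new suffix check behaves on a few illustrative Hugging Face tensor names (assumed inputs, not an exhaustive list):

    print(_is_need_transpose("model.layers.0.self_attn.o_proj.weight"))  # True  (ends with "o_proj.weight")
    print(_is_need_transpose("model.layers.0.mlp.gate.weight"))          # True  (ends with "gate.weight")
    print(_is_need_transpose("model.layers.0.input_layernorm.weight"))   # False (norm weights stay as-is)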
@@ -300,18 +234,17 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
     required_files = set()
     file_to_pd_param_name = defaultdict(list)
     pd_param_name_to_file = defaultdict(list)
-    for pd_name, p in model.named_parameters():
+    for pd_name, p in model.state_dict().items():
         hf_name = paddle_name_to_hf_names(pd_name)
-        if hf_name[0] in weight_map:
+        if len(hf_name) == 0:
+            logger.warning(f"the weight {pd_name} does not need to be loaded")
+        elif hf_name[0] in weight_map:
             filename = weight_map[hf_name[0]]
             required_files.add(filename)
             file_to_pd_param_name[filename].append(pd_name)
             pd_param_name_to_file[pd_name].append(filename)
         else:
             logger.warning(f"Warning: {pd_name} -> {hf_name[0]} not found in weight map")
-            import sys
-
-            sys.exit()

         if len(hf_name) > 1:
             if hf_name[1] in weight_map:
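Names that cannot be mapped are now logged and skipped rather than aborting the whole load via sys.exit(). A side effect of switching from model.named_parameters() to model.state_dict().items() is that registered buffers are visited as well, not only trainable parameters; a minimal, self-contained Paddle sketch with a toy layer (unrelated to this file's model):

    import paddle

    class Toy(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = paddle.nn.Linear(2, 2)
            self.register_buffer("scale", paddle.ones([1]))

    m = Toy()
    print(len(list(m.named_parameters())))  # 2 -> weight and bias only
    print(len(m.state_dict()))              # 3 -> weight, bias and the "scale" buffer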
@@ -339,7 +272,7 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
                 if len(hf_name) == 1:
                     tensor = f.get_tensor(hf_name[0])

-                    force_transpose = False
+                    force_transpose = _is_need_transpose(hf_name[0])

                     model.state_dict()[pd_param].set_value(
                         paddle.cast(
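The transpose flag matters because Hugging Face (PyTorch-style) checkpoints store Linear weights as [out_features, in_features] while Paddle's nn.Linear expects [in_features, out_features]. A hedged sketch of the effect (prepare_tensor's full body is not shown in this diff, so this only illustrates the shape handling):

    import paddle

    hf_weight = paddle.rand([128, 512])    # [out, in] as stored in the HF checkpoint
    if _is_need_transpose("model.layers.0.self_attn.o_proj.weight"):
        hf_weight = hf_weight.T            # [in, out] = [512, 128], as the Paddle parameter expects
    print(hf_weight.shape)                 # [512, 128]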
