Commit 79bc929

fix issues with convert_nemo_llama_to_hf.py (NVIDIA#7922)

1 parent 521cfb4 · commit 79bc929
File tree: 1 file changed (+14 −14 lines)

scripts/nlp_language_modeling/convert_nemo_llama_to_hf.py (+14 −14)

@@ -91,6 +91,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
         map_location = torch.device('cpu')
         model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True)
         model_config.use_cpu_initialization = True
+        model_config.tensor_model_parallel_size = 1
     else:
         map_location, model_config = None, None

@@ -113,7 +114,6 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
 
     param_to_weights = lambda param: param.to(dtype)
     checkpoint = OrderedDict()
-    checkpoint['state_dict'] = OrderedDict()
 
     hidden_size = model.cfg.hidden_size
     head_num = model.cfg.num_attention_heads
@@ -128,7 +128,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
     # Embedding
     embed_weight = model.state_dict()[f'model.embedding.word_embeddings.weight']
     embed_weights_base_name = f'model.embed_tokens.weight'
-    checkpoint['state_dict'][embed_weights_base_name] = param_to_weights(embed_weight)
+    checkpoint[embed_weights_base_name] = param_to_weights(embed_weight)
 
     for l in range(int(num_layers)):
         print(f"converting layer {l}")
@@ -158,14 +158,14 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
         k_weights_base_name = f'model.layers.{l}.self_attn.k_proj.weight'
         v_weights_base_name = f'model.layers.{l}.self_attn.v_proj.weight'
 
-        checkpoint['state_dict'][q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size))
-        checkpoint['state_dict'][k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size))
-        checkpoint['state_dict'][v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))
+        checkpoint[q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size))
+        checkpoint[k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size))
+        checkpoint[v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))
 
         # attention dense
         o_weight = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_proj.weight']
         o_weight_base_name = f'model.layers.{l}.self_attn.o_proj.weight'
-        checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight)
+        checkpoint[o_weight_base_name] = param_to_weights(o_weight)
 
         # mlp
         mlp_weights = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.weight']
@@ -175,31 +175,31 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) ->
         mlp_down_proj_base_name = f'model.layers.{l}.mlp.gate_proj.weight'
         mlp_gate_proj_base_name = f'model.layers.{l}.mlp.up_proj.weight'
 
-        checkpoint['state_dict'][mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
-        checkpoint['state_dict'][mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
+        checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
+        checkpoint[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
 
         mlp_up_proj_weight = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc2.weight']
         mlp_up_proj_base_name = f'model.layers.{l}.mlp.down_proj.weight'
-        checkpoint['state_dict'][mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)
+        checkpoint[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)
 
         # layernorm
         input_ln_weight = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight']
         input_ln_base_name = f'model.layers.{l}.input_layernorm.weight'
-        checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight)
+        checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight)
 
         post_attn_ln_weight = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight']
         post_attn_ln_base_name = f'model.layers.{l}.post_attention_layernorm.weight'
-        checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
+        checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
 
         print(f"done layer {l}")
 
     final_ln_weight = model.state_dict()[f'model.decoder.final_layernorm.weight']
     final_ln_base_name = f'model.norm.weight'
-    checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight)
+    checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight)
 
     output_layer_weight = model.state_dict()[f'model.output_layer.weight']
     output_layer_base_name = f'lm_head.weight'
-    checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)
+    checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight)
 
     os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
     torch.save(checkpoint, output_hf_file)
@@ -210,7 +210,7 @@ def replace_hf_weights(weights_file, input_hf_path, output_hf_path):
     model = AutoModelForCausalLM.from_pretrained(input_hf_path, local_files_only=True)
    nemo_exported = torch.load(weights_file)
 
-    model.load_state_dict(nemo_exported['state_dict'])
+    model.load_state_dict(nemo_exported)
     model.save_pretrained(output_hf_path)
     logging.info(f"Full HF model saved to {output_hf_path}")
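
Note on the change: with the nested ['state_dict'] level removed, the file written by torch.save(checkpoint, output_hf_file) is a flat OrderedDict keyed directly by Hugging Face parameter names (e.g. 'model.embed_tokens.weight', 'model.layers.0.self_attn.q_proj.weight'), which is exactly what replace_hf_weights now passes to load_state_dict. A minimal sketch of that consumer side, with placeholder paths not taken from the script:

    import torch
    from transformers import AutoModelForCausalLM

    # Placeholders: the .bin produced by convert() and a matching HF Llama checkpoint dir
    weights_file = "/path/to/exported_weights.bin"
    input_hf_path = "/path/to/hf_llama_checkpoint"

    # Flat OrderedDict of tensors; no 'state_dict' wrapper anymore
    nemo_exported = torch.load(weights_file)

    model = AutoModelForCausalLM.from_pretrained(input_hf_path, local_files_only=True)
    model.load_state_dict(nemo_exported)  # keys line up with the HF Llama module names
    model.save_pretrained("/path/to/output_hf_model")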
