Skip to content

Commit

Permalink
Fix base layer config and missing state dict entries caused by distributed checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
xuanzic committed Nov 17, 2023
1 parent 39b78a9 commit c85f3ac
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions scripts/nlp_language_modeling/convert_nemo_falcon_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ def get_original_key(new_key):

key = new_key.replace("decoder.layers", "transformer.h")

if model.cfg.new_decoder_architecture:
if model.cfg.mcore_customization_config.new_decoder_architecture:
key = key.replace("input_layernorm", "ln_attn")
key = key.replace("pre_mlp_layernorm", "ln_mlp")
else:
key = key.replace("input_layernorm", "input_layernorm")
if not model.cfg.parallel_attention:
if not model.cfg.mcore_customization_config.parallel_attention:
key = key.replace("post_self_attn_layernorm", "post_attention_layernorm")

key = key.replace("self_attention.linear_proj", "self_attention.dense")
Expand All @@ -144,7 +144,10 @@ def get_original_key(new_key):
prefix = 'model.module.' if any(k.startswith('model.module.') for k in model.state_dict()) else 'model.'

for key, value in model.state_dict().items():
if '_extra_state' in key:
continue
orig_key = get_original_key(key)
print(f'Converting {key} to {orig_key}')
checkpoint['state_dict'][orig_key] = param_to_weights(value)

os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
Expand Down

0 comments on commit c85f3ac

Please sign in to comment.