Switch from Mistral checkpoint to HF-Mistral.
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Dec 20, 2023
1 parent 6fe099f commit c1b0eb0
Showing 1 changed file with 35 additions and 30 deletions.
@@ -15,12 +15,13 @@
r"""
Conversion script to convert Hugging Face Mistral-7B checkpoints into a NeMo checkpoint.
Example to run this conversion script:
python convert_mistral_7b_to_nemo.py \
python convert_hf_mistral_7b_to_nemo.py \
--in-file <path_to_mistral_checkpoints_folder> \
--out-file <path_to_output_nemo_file> \
[--fast-swiglu]
"""


import json
import os
from argparse import ArgumentParser
@@ -97,26 +98,26 @@ def load_config(mistral_config, tokenizer_path):
).model
# akoumparouli: verify this.
nemo_config.encoder_seq_length = mistral_config['sliding_window']
nemo_config.num_layers = int(mistral_config['n_layers'])
nemo_config.hidden_size = mistral_config['dim']
nemo_config.ffn_hidden_size = mistral_config['hidden_dim']
nemo_config.num_attention_heads = mistral_config['n_heads']
nemo_config.max_position_embeddings = 32_768
nemo_config.window_size = mistral_config['sliding_window']
nemo_config.init_method_std = 0.02
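# The config is now read with HF `config.json` key names (num_hidden_layers,
# hidden_size, intermediate_size, ...) in place of the raw Mistral
# `params.json` names used previously (n_layers, dim, hidden_dim, ...).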
nemo_config.num_layers = int(mistral_config['num_hidden_layers'])
nemo_config.hidden_size = mistral_config['hidden_size']
nemo_config.ffn_hidden_size = mistral_config['intermediate_size']
nemo_config.num_attention_heads = mistral_config['num_attention_heads']
nemo_config.max_position_embeddings = mistral_config['max_position_embeddings']
nemo_config.window_size = [mistral_config['sliding_window'], 0]
nemo_config.init_method_std = mistral_config['initializer_range']
# RMSNorm's epsilon.
nemo_config.layernorm_epsilon = mistral_config['norm_eps']
nemo_config.layernorm_epsilon = mistral_config['rms_norm_eps']
nemo_config.normalization = 'RMSNorm'

if 'n_kv_heads' in mistral_config:
nemo_config.num_query_groups = mistral_config['n_kv_heads']
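# HF exposes the grouped-query-attention parameter as `num_key_value_heads`;
# NeMo's equivalent knob is `num_query_groups`.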
if 'num_key_value_heads' in mistral_config:
nemo_config.num_query_groups = mistral_config['num_key_value_heads']
nemo_config.use_cpu_initialization = True
# Mistral uses SiLU, but it is the same as swish with beta = 1.
nemo_config.activation = 'fast-swiglu'

nemo_config.tokenizer.model = tokenizer_path
# TODO(@akoumparouli): rope_scaling.
nemo_config['rotary_base'] = 10000.0
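# Read the RoPE base frequency from the HF config (`rope_theta`) instead of
# hard-coding 10000.0.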
nemo_config['rotary_base'] = mistral_config['rope_theta']

base = 128
while mistral_config['vocab_size'] % base != 0:
@@ -127,14 +128,18 @@ def load_config(mistral_config, tokenizer_path):


def load_mistral_ckpt(dir):
params_file = os.path.join(dir, 'params.json')
params_file = os.path.join(dir, 'config.json')
assert os.path.exists(params_file)
with open(params_file, 'r') as fp:
model_args = json.load(fp)

ckpt_file = os.path.join(dir, 'consolidated.00.pth')
assert os.path.exists(ckpt_file)
ckpt = torch.load(ckpt_file)
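# The HF export ships as two sharded files, `pytorch_model-00001-of-00002.bin`
# and `pytorch_model-00002-of-00002.bin`, instead of a single
# `consolidated.00.pth`; each shard is loaded (memory-mapped via mmap=True to
# limit host RAM usage) and merged into one flat state dict.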
ckpt = OrderedDict()
ckpt['state_dict'] = OrderedDict()
for i in range(2):
ckpt_file = f'pytorch_model-0000{i+1}-of-00002.bin'
ckpt_path = os.path.join(dir, ckpt_file)
assert os.path.exists(ckpt_path)
ckpt.update(torch.load(ckpt_path, mmap=True))
tokenizer_file = os.path.join(dir, 'tokenizer.model')
assert os.path.exists(tokenizer_file)
tokenizer = SentencePieceProcessor(model_file=tokenizer_file)
@@ -208,7 +213,7 @@ def convert(args):
checkpoint = OrderedDict()
checkpoint['state_dict'] = OrderedDict()

embed_weight = ckpt[f'tok_embeddings.weight']
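# HF stores the token embedding as `model.embed_tokens.weight` (formerly
# `tok_embeddings.weight` in the raw Mistral dump).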
embed_weight = ckpt[f'model.embed_tokens.weight']
if mcore_gpt:
embed_weights_base_name = f'model.embedding.word_embeddings.weight'
else:
@@ -225,13 +230,13 @@ def convert(args):

for l in range(int(num_layers)):
print(f"converting layer {l}")
old_tensor_shape = ckpt[f'layers.{l}.attention.wq.weight'].size()
old_tensor_shape = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].size()
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

q = ckpt[f'layers.{l}.attention.wq.weight'].view(*new_q_tensor_shape)
k = ckpt[f'layers.{l}.attention.wk.weight'].view(*new_kv_tensor_shape)
v = ckpt[f'layers.{l}.attention.wv.weight'].view(*new_kv_tensor_shape)
q = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].view(*new_q_tensor_shape)
k = ckpt[f'model.layers.{l}.self_attn.k_proj.weight'].view(*new_kv_tensor_shape)
v = ckpt[f'model.layers.{l}.self_attn.v_proj.weight'].view(*new_kv_tensor_shape)
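# qkv_weights is assembled group by group into the fused layout that
# mcore's linear_qkv expects: heads_per_group Q heads followed by the
# group's shared K and V head, repeated for every query group.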
qkv_weights = torch.empty((0, head_size) + old_tensor_shape[1:])
heads_per_group = head_num // num_query_groups
for i in range(num_query_groups):
@@ -247,40 +252,40 @@ def convert(args):
checkpoint['state_dict'][qkv_weights_base_name] = param_to_weights(qkv_weights)

# attention dense
o_weight = ckpt[f'layers.{l}.attention.wo.weight']
o_weight = ckpt[f'model.layers.{l}.self_attn.o_proj.weight']
if mcore_gpt:
o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight'
else:
o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight'
checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight)

# MLP
mlp_down_weight = ckpt[f'layers.{l}.feed_forward.w1.weight']
mlp_gate_weight = ckpt[f'layers.{l}.feed_forward.w3.weight']
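# NeMo's fused SwiGLU projection (linear_fc1 / dense_h_to_4h) packs the HF
# gate_proj and up_proj matrices into a single weight, hence the torch.cat
# along dim 0 below.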
mlp_down_weight = ckpt[f'model.layers.{l}.mlp.gate_proj.weight']
mlp_gate_weight = ckpt[f'model.layers.{l}.mlp.up_proj.weight']
if mcore_gpt:
mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight'
else:
mlp_down_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight'
mlp_down_weight = torch.cat((mlp_down_weight, mlp_gate_weight), axis=0)
checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight)

mlp_up_weight = ckpt[f'layers.{l}.feed_forward.w2.weight']
mlp_up_weight = ckpt[f'model.layers.{l}.mlp.down_proj.weight']
if mcore_gpt:
mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight'
else:
mlp_up_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_4h_to_h.weight'
checkpoint['state_dict'][mlp_up_base_name] = param_to_weights(mlp_up_weight)

# LayerNorm
input_ln_weight = ckpt[f'layers.{l}.attention_norm.weight']
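# With mcore, the pre-attention and pre-MLP RMSNorm weights are stored as the
# layer_norm_weight of the fused linear they feed into (linear_qkv and
# linear_fc1 respectively), rather than as standalone layernorm parameters.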
input_ln_weight = ckpt[f'model.layers.{l}.input_layernorm.weight']

if mcore_gpt:
input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'
else:
input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight'
checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight)

post_attn_ln_weight = ckpt[f'layers.{l}.ffn_norm.weight']
post_attn_ln_weight = ckpt[f'model.layers.{l}.post_attention_layernorm.weight']
if mcore_gpt:
post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'
else:
@@ -289,22 +294,22 @@ def convert(args):

print(f"done layer {l}")

final_ln_weight = ckpt[f'norm.weight']
final_ln_weight = ckpt[f'model.norm.weight']
if mcore_gpt:
final_ln_base_name = f'model.decoder.final_layernorm.weight'
else:
final_ln_base_name = f'model.language_model.encoder.final_layernorm.weight'
checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight)

output_layer_weight = ckpt[f'output.weight']
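# The HF LM head (`lm_head.weight`, formerly `output.weight`) maps to NeMo's
# output layer.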
output_layer_weight = ckpt[f'lm_head.weight']
if mcore_gpt:
output_layer_base_name = f'model.output_layer.weight'
else:
output_layer_base_name = f'model.language_model.output_layer.weight'
checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)

checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config

torch.save(checkpoint, 'ckpt')
del ckpt

if nemo_config.get('megatron_amp_O2', False):
