From c1b0eb01b8807e561af8aaa959f932fc4d0643fa Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Tue, 19 Dec 2023 19:18:45 +0000
Subject: [PATCH] Switch from Mistral checkpoint to HF-Mistral.

Signed-off-by: Alexandros Koumparoulis
---
 ...mo.py => convert_hf_mistral_7b_to_nemo.py} | 65 ++++++++++---------
 1 file changed, 35 insertions(+), 30 deletions(-)
 rename scripts/nlp_language_modeling/{convert_mistral_7b_to_nemo.py => convert_hf_mistral_7b_to_nemo.py} (85%)

diff --git a/scripts/nlp_language_modeling/convert_mistral_7b_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py
similarity index 85%
rename from scripts/nlp_language_modeling/convert_mistral_7b_to_nemo.py
rename to scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py
index 6ace24acd1f9..b7478b5725b7 100644
--- a/scripts/nlp_language_modeling/convert_mistral_7b_to_nemo.py
+++ b/scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py
@@ -15,12 +15,13 @@
 r"""
 Conversion script to convert Huggingface Mistral-7b checkpoints into nemo checkpoint.
   Example to run this conversion script:
-    python convert_mistral_7b_to_nemo.py \
+    python convert_hf_mistral_7b_to_nemo.py \
     --in-file <path_to_mistral_checkpoints_folder> \
     --out-file <path_to_output_nemo_file> \
     [--fast-swiglu\
 """
+
 import json
 import os
 from argparse import ArgumentParser
@@ -97,26 +98,26 @@ def load_config(mistral_config, tokenizer_path):
     ).model
     # akoumparouli: verify this.
     nemo_config.encoder_seq_length = mistral_config['sliding_window']
-    nemo_config.num_layers = int(mistral_config['n_layers'])
-    nemo_config.hidden_size = mistral_config['dim']
-    nemo_config.ffn_hidden_size = mistral_config['hidden_dim']
-    nemo_config.num_attention_heads = mistral_config['n_heads']
-    nemo_config.max_position_embeddings = 32_768
-    nemo_config.window_size = mistral_config['sliding_window']
-    nemo_config.init_method_std = 0.02
+    nemo_config.num_layers = int(mistral_config['num_hidden_layers'])
+    nemo_config.hidden_size = mistral_config['hidden_size']
+    nemo_config.ffn_hidden_size = mistral_config['intermediate_size']
+    nemo_config.num_attention_heads = mistral_config['num_attention_heads']
+    nemo_config.max_position_embeddings = mistral_config['max_position_embeddings']
+    nemo_config.window_size = [mistral_config['sliding_window'], 0]
+    nemo_config.init_method_std = mistral_config['initializer_range']
     # RMSNorm's epsilon.
-    nemo_config.layernorm_epsilon = mistral_config['norm_eps']
+    nemo_config.layernorm_epsilon = mistral_config['rms_norm_eps']
     nemo_config.normalization = 'RMSNorm'
 
-    if 'n_kv_heads' in mistral_config:
-        nemo_config.num_query_groups = mistral_config['n_kv_heads']
+    if 'num_key_value_heads' in mistral_config:
+        nemo_config.num_query_groups = mistral_config['num_key_value_heads']
     nemo_config.use_cpu_initialization = True
     # Mistral uses SiLU, but it is the same as swish with beta = 1.
     nemo_config.activation = 'fast-swiglu'
 
     nemo_config.tokenizer.model = tokenizer_path
     # TODO(@akoumparouli): rope_scaling.
-    nemo_config['rotary_base'] = 10000.0
+    nemo_config['rotary_base'] = mistral_config['rope_theta']
 
     base = 128
     while mistral_config['vocab_size'] % base != 0:
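
For reference, the hunk above switches load_config() from the old params.json field names to the ones published in a Hugging Face config.json. A minimal sketch of reading those fields, assuming a locally downloaded HF Mistral-7B folder (the directory path is a placeholder, not part of the patch):

import json
import os

hf_dir = '/path/to/Mistral-7B-v0.1'  # placeholder; point this at the downloaded HF folder
with open(os.path.join(hf_dir, 'config.json')) as fp:
    hf_cfg = json.load(fp)

# Renames the converter now relies on:
#   n_layers -> num_hidden_layers, dim -> hidden_size, hidden_dim -> intermediate_size,
#   n_heads -> num_attention_heads, n_kv_heads -> num_key_value_heads,
#   norm_eps -> rms_norm_eps, plus rope_theta instead of a hard-coded 10000.0.
for key in ('num_hidden_layers', 'hidden_size', 'intermediate_size',
            'num_attention_heads', 'num_key_value_heads', 'sliding_window',
            'max_position_embeddings', 'rms_norm_eps', 'initializer_range',
            'rope_theta', 'vocab_size'):
    print(key, hf_cfg[key])
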
@@ -127,14 +128,18 @@
 def load_mistral_ckpt(dir):
-    params_file = os.path.join(dir, 'params.json')
+    params_file = os.path.join(dir, 'config.json')
     assert os.path.exists(params_file)
     with open(params_file, 'r') as fp:
         model_args = json.load(fp)
 
-    ckpt_file = os.path.join(dir, 'consolidated.00.pth')
-    assert os.path.exists(ckpt_file)
-    ckpt = torch.load(ckpt_file)
+    ckpt = OrderedDict()
+    ckpt['state_dict'] = OrderedDict()
+    for i in range(2):
+        ckpt_file = f'pytorch_model-0000{i+1}-of-00002.bin'
+        ckpt_path = os.path.join(dir, ckpt_file)
+        assert os.path.exists(ckpt_path)
+        ckpt.update(torch.load(ckpt_path, mmap=True))
 
     tokenizer_file = os.path.join(dir, 'tokenizer.model')
     assert os.path.exists(tokenizer_file)
     tokenizer = SentencePieceProcessor(model_file=tokenizer_file)
@@ -208,7 +213,7 @@ def convert(args):
     checkpoint = OrderedDict()
     checkpoint['state_dict'] = OrderedDict()
 
-    embed_weight = ckpt[f'tok_embeddings.weight']
+    embed_weight = ckpt[f'model.embed_tokens.weight']
     if mcore_gpt:
         embed_weights_base_name = f'model.embedding.word_embeddings.weight'
     else:
@@ -225,13 +230,13 @@
     for l in range(int(num_layers)):
         print(f"converting layer {l}")
 
-        old_tensor_shape = ckpt[f'layers.{l}.attention.wq.weight'].size()
+        old_tensor_shape = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].size()
         new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
         new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]
 
-        q = ckpt[f'layers.{l}.attention.wq.weight'].view(*new_q_tensor_shape)
-        k = ckpt[f'layers.{l}.attention.wk.weight'].view(*new_kv_tensor_shape)
-        v = ckpt[f'layers.{l}.attention.wv.weight'].view(*new_kv_tensor_shape)
+        q = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].view(*new_q_tensor_shape)
+        k = ckpt[f'model.layers.{l}.self_attn.k_proj.weight'].view(*new_kv_tensor_shape)
+        v = ckpt[f'model.layers.{l}.self_attn.v_proj.weight'].view(*new_kv_tensor_shape)
         qkv_weights = torch.empty((0, head_size) + old_tensor_shape[1:])
         heads_per_group = head_num // num_query_groups
         for i in range(num_query_groups):
@@ -247,7 +252,7 @@
         checkpoint['state_dict'][qkv_weights_base_name] = param_to_weights(qkv_weights)
 
         # attention dense
-        o_weight = ckpt[f'layers.{l}.attention.wo.weight']
+        o_weight = ckpt[f'model.layers.{l}.self_attn.o_proj.weight']
         if mcore_gpt:
             o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight'
         else:
@@ -255,8 +260,8 @@
         checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight)
 
         # MLP
-        mlp_down_weight = ckpt[f'layers.{l}.feed_forward.w1.weight']
-        mlp_gate_weight = ckpt[f'layers.{l}.feed_forward.w3.weight']
+        mlp_down_weight = ckpt[f'model.layers.{l}.mlp.gate_proj.weight']
+        mlp_gate_weight = ckpt[f'model.layers.{l}.mlp.up_proj.weight']
         if mcore_gpt:
             mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight'
         else:
@@ -264,7 +269,7 @@
         mlp_down_weight = torch.cat((mlp_down_weight, mlp_gate_weight), axis=0)
         checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight)
 
-        mlp_up_weight = ckpt[f'layers.{l}.feed_forward.w2.weight']
+        mlp_up_weight = ckpt[f'model.layers.{l}.mlp.down_proj.weight']
         if mcore_gpt:
             mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight'
         else:
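
For reference, the per-layer hunks above pack the HF q_proj/k_proj/v_proj matrices into a single grouped-query QKV tensor: for every query group, heads_per_group query heads are appended, followed by that group's one key head and one value head. A standalone sketch of that packing with toy sizes (not Mistral-7B's real dimensions; the final reshape is an assumption about the code outside the hunks shown):

import torch

# Toy sizes; Mistral-7B itself uses 32 query heads, 8 KV heads, head_size 128, hidden 4096.
head_num, num_query_groups, head_size, hidden = 8, 2, 4, 32
heads_per_group = head_num // num_query_groups

q = torch.randn(head_num * head_size, hidden).view(head_num, head_size, hidden)
k = torch.randn(num_query_groups * head_size, hidden).view(num_query_groups, head_size, hidden)
v = torch.randn(num_query_groups * head_size, hidden).view(num_query_groups, head_size, hidden)

qkv = torch.empty((0, head_size, hidden))
for g in range(num_query_groups):
    # heads_per_group query heads, then the group's single key head and value head
    qkv = torch.cat((qkv, q[g * heads_per_group : (g + 1) * heads_per_group]))
    qkv = torch.cat((qkv, k[g : g + 1]))
    qkv = torch.cat((qkv, v[g : g + 1]))

qkv = qkv.reshape(head_size * (head_num + 2 * num_query_groups), hidden)
print(qkv.shape)  # torch.Size([48, 32]) with the toy sizes: (8 + 2 * 2) * 4 = 48 rows
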
@@ -272,7 +277,7 @@ def convert(args):
         checkpoint['state_dict'][mlp_up_base_name] = param_to_weights(mlp_up_weight)
 
         # LayerNorm
-        input_ln_weight = ckpt[f'layers.{l}.attention_norm.weight']
+        input_ln_weight = ckpt[f'model.layers.{l}.input_layernorm.weight']
         if mcore_gpt:
             input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'
@@ -280,7 +285,7 @@ def convert(args):
             input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight'
         checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight)
 
-        post_attn_ln_weight = ckpt[f'layers.{l}.ffn_norm.weight']
+        post_attn_ln_weight = ckpt[f'model.layers.{l}.post_attention_layernorm.weight']
         if mcore_gpt:
             post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'
         else:
@@ -289,14 +294,14 @@ def convert(args):
 
         print(f"done layer {l}")
 
-    final_ln_weight = ckpt[f'norm.weight']
+    final_ln_weight = ckpt[f'model.norm.weight']
     if mcore_gpt:
         final_ln_base_name = f'model.decoder.final_layernorm.weight'
     else:
         final_ln_base_name = f'model.language_model.encoder.final_layernorm.weight'
     checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight)
 
-    output_layer_weight = ckpt[f'output.weight']
+    output_layer_weight = ckpt[f'lm_head.weight']
     if mcore_gpt:
         output_layer_base_name = f'model.output_layer.weight'
     else:
@@ -304,7 +309,7 @@ def convert(args):
     checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)
 
     checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config
-
+    torch.save(checkpoint, 'ckpt')
     del ckpt
 
     if nemo_config.get('megatron_amp_O2', False):
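
For reference, a minimal sketch of assembling the sharded HF state dict the way the patched load_mistral_ckpt() does, then spot-checking a few of the renamed keys the converter reads (the directory path is a placeholder; torch.load(..., mmap=True) needs PyTorch 2.1 or newer):

import os
from collections import OrderedDict

import torch

hf_dir = '/path/to/Mistral-7B-v0.1'  # placeholder; point this at the downloaded HF folder
state_dict = OrderedDict()
for i in range(2):
    shard = os.path.join(hf_dir, f'pytorch_model-0000{i + 1}-of-00002.bin')
    state_dict.update(torch.load(shard, mmap=True))

for key in ('model.embed_tokens.weight',
            'model.layers.0.self_attn.q_proj.weight',
            'model.layers.0.mlp.gate_proj.weight',
            'model.norm.weight',
            'lm_head.weight'):
    print(key, tuple(state_dict[key].shape))

With those files in place, the renamed script is invoked as shown in the updated docstring, with --in-file pointing at the HF checkpoint folder and --out-file naming the output .nemo file.
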