Nemo to HF converter for LLaMA model (NVIDIA#7770)
* Create config_llama_truncate.yaml
* Add files via upload
* Update convert_nemo_llama_to_hf.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update config_llama_truncate.yaml
* Update convert_nemo_llama_to_hf.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Update convert_nemo_llama_to_hf.py
* Clean up trainer
* Remove dependency on yaml config; load config from the nemo file instead
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Enable ckpt saving into other precision formats
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Support 70b + clean up qkv slice logic
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix bug
* Move HF model folder code from comment to function and add instructions to run
* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Signed-off-by: Utkarsh <[email protected]>
Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Showing 1 changed file with 225 additions and 0 deletions.
scripts/nlp_language_modeling/convert_nemo_llama_to_hf.py
@@ -0,0 +1,225 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
from pytorch_lightning import Trainer
from transformers import AutoModelForCausalLM

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils import logging

""" | ||
Script to convert a llama2 checkpoint in nemo (mcore path) into a HuggingFace checkpoint. | ||
This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder. | ||
1) Generate only HF weights from a nemo file: | ||
python convert_nemo_llama_to_hf.py \ | ||
--in-file /path/to/file.nemo or /path/to/extracted_folder \ | ||
--out-file /path/to/pytorch_model.bin | ||
2) Generate the full HF model folder | ||
python convert_nemo_llama_to_hf.py \ | ||
--in-file /path/to/file.nemo or /path/to/extracted_folder \ | ||
--out-file /path/to/pytorch_model.bin \ | ||
--hf-in-file /path/to/input_hf_folder \ | ||
--hf-out-file /path/to/output_hf_folder | ||
Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Llama2 70b). | ||
However this option makes the conversion script significantly slower. | ||
""" | ||
|
||
|
||
def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--in-file", type=str, default=None, required=True, help="Path to .nemo file",
    )
    parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to HF .bin file")
    parser.add_argument(
        "--hf-in-path",
        type=str,
        default=None,
        help="A HF model path, e.g. a folder containing https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main",
    )
    parser.add_argument(
        "--hf-out-path",
        type=str,
        default=None,
        help="Output HF model path, with the same format as above but with the user's own weights",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default=None,
        help="Precision of output weights. "
        "Defaults to precision of the input nemo weights (model.cfg.trainer.precision)",
    )
    parser.add_argument(
        "--cpu-only",
        action="store_true",
        help="Load model in cpu only. Useful if the model cannot fit in GPU memory, "
        "but this option makes the conversion script significantly slower.",
    )
    args = parser.parse_args()
    return args


def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None:
    """
    Convert NeMo weights to HF weights
    """
    dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
    if cpu_only:
        map_location = torch.device('cpu')
        model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True)
        model_config.use_cpu_initialization = True
    else:
        map_location, model_config = None, None

    if cpu_only:
        logging.info("******** Loading model on CPU. This will take a significant amount of time.")
    model = MegatronGPTModel.restore_from(
        input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
    )
    if precision is None:
        precision = model.cfg.precision
    if precision in [32, "32"]:
        dtype = torch.float32
    elif precision in [16, "16", "16-mixed"]:
        dtype = torch.float16
    elif precision in ["bf16", "bf16-mixed"]:
        dtype = torch.bfloat16
    else:
        logging.warning(f"Precision string {precision} is not recognized, falling back to fp32")
        dtype = torch.float32  # fallback

    param_to_weights = lambda param: param.to(dtype)
    checkpoint = OrderedDict()
    checkpoint['state_dict'] = OrderedDict()

    hidden_size = model.cfg.hidden_size
    head_num = model.cfg.num_attention_heads
    num_layers = model.cfg.num_layers
    ffn_hidden_size = model.cfg.ffn_hidden_size
    num_query_groups = model.cfg.get("num_query_groups", head_num)  # different num_query_groups for 70B

    head_size = hidden_size // head_num
    heads_per_group = head_num // num_query_groups
    qkv_total_dim = head_num + 2 * num_query_groups

    # Embedding
    embed_weight = model.state_dict()['model.embedding.word_embeddings.weight']
    embed_weights_base_name = 'model.embed_tokens.weight'
    checkpoint['state_dict'][embed_weights_base_name] = param_to_weights(embed_weight)

    for l in range(int(num_layers)):
        print(f"converting layer {l}")

        qkv_weights = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight']
        qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])

        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
        ## Example of slices
        ## 7b: num_query_groups = head_num = 32,
        ## q_slice = [0, 3, 6, 9 , ... 90, 93]
        ## k_slice = [1, 4, 7, 10, ... 91, 94]
        ## v_slice = [2, 5, 8, 11, ... 92, 95]
        ## 70b (with GQA): num_query_groups = 8, head_num = 64
        ## q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77]
        ## k_slice = [8, 18, 28, ... 68, 78]
        ## v_slice = [9, 19, 29, ... 69, 79]
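        ## Optional sanity check (an added illustration, assuming the fused per-group
        ## layout [q ... q, k, v] described above): the three slices together should
        ## cover every row of the reshaped QKV matrix exactly once.
        assert sorted(torch.cat([q_slice, k_slice, v_slice]).tolist()) == list(range(qkv_total_dim))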

        q_weights_base_name = f'model.layers.{l}.self_attn.q_proj.weight'
        k_weights_base_name = f'model.layers.{l}.self_attn.k_proj.weight'
        v_weights_base_name = f'model.layers.{l}.self_attn.v_proj.weight'

        checkpoint['state_dict'][q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size))
        checkpoint['state_dict'][k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size))
        checkpoint['state_dict'][v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))

        # attention dense
        o_weight = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_proj.weight']
        o_weight_base_name = f'model.layers.{l}.self_attn.o_proj.weight'
        checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight)

        # mlp: NeMo's linear_fc1 fuses the gate and up projections; linear_fc2 is the down projection
        mlp_weights = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.weight']
        mlp_gate_proj_weight = mlp_weights[:ffn_hidden_size, :]
        mlp_up_proj_weight = mlp_weights[ffn_hidden_size:, :]

        mlp_gate_proj_base_name = f'model.layers.{l}.mlp.gate_proj.weight'
        mlp_up_proj_base_name = f'model.layers.{l}.mlp.up_proj.weight'

        checkpoint['state_dict'][mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
        checkpoint['state_dict'][mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)

        mlp_down_proj_weight = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc2.weight']
        mlp_down_proj_base_name = f'model.layers.{l}.mlp.down_proj.weight'
        checkpoint['state_dict'][mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)

        # layernorm
        input_ln_weight = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight']
        input_ln_base_name = f'model.layers.{l}.input_layernorm.weight'
        checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight)

        post_attn_ln_weight = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight']
        post_attn_ln_base_name = f'model.layers.{l}.post_attention_layernorm.weight'
        checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)

        print(f"done layer {l}")

    final_ln_weight = model.state_dict()['model.decoder.final_layernorm.weight']
    final_ln_base_name = 'model.norm.weight'
    checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight)

    output_layer_weight = model.state_dict()['model.output_layer.weight']
    output_layer_base_name = 'lm_head.weight'
    checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)

    os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
    torch.save(checkpoint, output_hf_file)
    logging.info(f"Weights saved to {output_hf_file}")


def replace_hf_weights(weights_file, input_hf_path, output_hf_path):
    model = AutoModelForCausalLM.from_pretrained(input_hf_path, local_files_only=True)
    nemo_exported = torch.load(weights_file)

    model.load_state_dict(nemo_exported['state_dict'])
    model.save_pretrained(output_hf_path)
    logging.info(f"Full HF model saved to {output_hf_path}")


if __name__ == '__main__':
    args = get_args()
    convert(args.in_file, args.out_file, precision=args.precision, cpu_only=args.cpu_only)
    if args.hf_in_path and args.hf_out_path:
        replace_hf_weights(args.out_file, args.hf_in_path, args.hf_out_path)
    else:
        logging.info("`hf-in-path` and/or `hf-out-path` not provided, not generating full HF model.")
        logging.info(f".bin file is saved to {args.out_file}")