Mistral 7b conversion script (NVIDIA#8052)
* Import script for mistral-7b.

From the Mistral checkpoint, not HF.
Pending: support for block-diagonal attention mask.

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add window_size to nemo_config.

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Switch from Mistral checkpoint to HF-Mistral.

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Force lowercase when checking for normalization type.

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* NeMo-Mistral-7B to HF-Mistral-7B.

Signed-off-by: Alexandros Koumparoulis <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Sasha Meister <[email protected]>
3 people authored and sashameister committed Feb 15, 2024
1 parent 4df1606 commit 8ec00c8
Showing 3 changed files with 567 additions and 1 deletion.
@@ -1596,7 +1596,7 @@ def build_transformer_config(self) -> TransformerConfig:
For attributes in TransformerConfig that are not in the nemo model config, we add custom logic.
"""

normalization = self.cfg.get('normalization', 'layernorm')
normalization = self.cfg.get('normalization', 'layernorm').lower()
layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p'
if normalization == 'layernorm':
normalization = 'LayerNorm'
341 changes: 341 additions & 0 deletions scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py
@@ -0,0 +1,341 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Conversion script to convert HuggingFace Mistral-7B checkpoints into nemo checkpoint.
Example to run this conversion script:
python convert_hf_mistral_7b_to_nemo.py \
--in-file <path_to_mistral_checkpoints_folder> \
--out-file <path_to_output_nemo_file> \
[--fast-swiglu\
"""


import json
import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
import torch.nn
from omegaconf import OmegaConf
from pytorch_lightning.core.saving import _load_state as ptl_load_state
from pytorch_lightning.trainer.trainer import Trainer
from sentencepiece import SentencePieceProcessor

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.utils import logging


def get_args():
parser = ArgumentParser()
parser.add_argument(
"--in-file", type=str, default=None, required=True, help="Path to Huggingface Mistral-7b checkpoints",
)
parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.")
parser.add_argument("--precision", type=str, default="32", help="Model precision")
args = parser.parse_args()
return args


def load_model(cls, checkpoint, strict, **kwargs):
try:
if 'cfg' in kwargs:
model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
else:
model = cls(cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs)
for name, module in model.named_parameters():
if name in checkpoint['state_dict']:
module.data = checkpoint['state_dict'][name]
checkpoint['state_dict'].pop(name)
else:
print(f"Unexpected key: {name} not in checkpoint but in model.")

for name, buffer in model.named_buffers():
if name in checkpoint['state_dict']:
buffer.data = checkpoint['state_dict'][name]
checkpoint['state_dict'].pop(name)

if len(checkpoint['state_dict'].keys()) != 0:
raise RuntimeError(
f"Additional keys: {checkpoint['state_dict'].keys()} in checkpoint but not in model."
)

# register the artifacts
cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
if cfg.tokenizer.model is not None:
model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model)
if cfg.tokenizer.vocab_file is not None:
model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
if cfg.tokenizer.merge_file is not None:
model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file)
finally:
cls._set_model_restore_state(is_being_restored=False)
return model


def load_config(mistral_config, tokenizer_path):
nemo_config = OmegaConf.load(
os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml')
).model
# akoumparouli: verify this.
nemo_config.encoder_seq_length = mistral_config['sliding_window']
nemo_config.num_layers = int(mistral_config['num_hidden_layers'])
nemo_config.hidden_size = mistral_config['hidden_size']
nemo_config.ffn_hidden_size = mistral_config['intermediate_size']
nemo_config.num_attention_heads = mistral_config['num_attention_heads']
nemo_config.max_position_embeddings = mistral_config['max_position_embeddings']
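    # Sliding-window attention as a [left, right] window: look back `sliding_window` tokens, no right context.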
nemo_config.window_size = [mistral_config['sliding_window'], 0]
nemo_config.init_method_std = mistral_config['initializer_range']
# RMSNorm's epsilon.
nemo_config.layernorm_epsilon = mistral_config['rms_norm_eps']
nemo_config.normalization = 'rmsnorm'

if 'num_key_value_heads' in mistral_config:
nemo_config.num_query_groups = mistral_config['num_key_value_heads']
nemo_config.use_cpu_initialization = True
# Mistral uses SiLU, but it is the same as swish with beta = 1.
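    # SiLU(x) = x * sigmoid(x), i.e. Swish with beta = 1; combined with the gate projection this forms SwiGLU.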
nemo_config.activation = 'fast-swiglu'

nemo_config.tokenizer.model = tokenizer_path
# TODO(@akoumparouli): rope_scaling.
nemo_config['rotary_base'] = mistral_config['rope_theta']

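    # Find the largest power of two (<= 128) that evenly divides the vocab size, so no vocab padding is required.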
base = 128
while mistral_config['vocab_size'] % base != 0:
base //= 2
nemo_config.make_vocab_size_divisible_by = base

return nemo_config


def load_mistral_ckpt(dir):
params_file = os.path.join(dir, 'config.json')
assert os.path.exists(params_file)
with open(params_file, 'r') as fp:
model_args = json.load(fp)

ckpt = OrderedDict()
ckpt['state_dict'] = OrderedDict()
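    # HF Mistral-7B ships as two pytorch_model-*.bin shards; merge them into a single flat dict.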
for i in range(2):
ckpt_file = f'pytorch_model-0000{i+1}-of-00002.bin'
ckpt_path = os.path.join(dir, ckpt_file)
assert os.path.exists(ckpt_path)
ckpt.update(torch.load(ckpt_path))
tokenizer_file = os.path.join(dir, 'tokenizer.model')
assert os.path.exists(tokenizer_file)
tokenizer = SentencePieceProcessor(model_file=tokenizer_file)
assert tokenizer.get_piece_size() == model_args['vocab_size']
return model_args, ckpt, tokenizer


def convert(args):
logging.info(f"loading checkpoint {args.in_file}")

model_args, ckpt, tokenizer = load_mistral_ckpt(args.in_file)
nemo_config = load_config(model_args, os.path.join(args.in_file, 'tokenizer.model'))
logging.info(f"loaded checkpoint {args.in_file}")

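    # Map --precision to the Trainer's precision value; fall back to fp16 when bf16 is unsupported on this device.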
if args.precision in ["32", "16"]:
precision = int(float(args.precision))
elif args.precision in ["bf16", "bf16-mixed"]:
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
precision = args.precision
else:
logging.warning("BF16 is not supported on this device. Using FP16 instead.")
precision = args.precision[2:] # prune bf in string
else:
precision = args.precision

plugins = []
if precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
scaler = None
if precision in [16, '16', '16-mixed']:
scaler = GradScaler(
init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32),
growth_interval=nemo_config.get('native_amp_growth_interval', 1000),
hysteresis=nemo_config.get('hysteresis', 2),
)
# MixedPrecisionPlugin in PTL >= 2.0 requires precision to be 16-mixed or bf16-mixed
plugin_precision = '16-mixed'
else:
plugin_precision = 'bf16-mixed'

if nemo_config.get('megatron_amp_O2', False):
plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
else:
plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))

if precision == 32:
dtype = torch.float32
elif precision in [16, "16", "16-mixed"]:
dtype = torch.float16
elif precision in ["bf16", "bf16-mixed"]:
dtype = torch.bfloat16
else:
dtype = torch.float32 # fallback

nemo_config.precision = precision
logging.info(f"nemo_config: {nemo_config}")

trainer = Trainer(plugins=plugins, accelerator='cpu', precision=precision, strategy=NLPDDPStrategy())

hidden_size = nemo_config.hidden_size
head_num = nemo_config.num_attention_heads
head_size = hidden_size // head_num
num_layers = nemo_config.num_layers

mcore_gpt = nemo_config.mcore_gpt

assert mcore_gpt == nemo_config.get(
'transformer_engine', False
), "mcore_gpt transformer_engine must be enabled (or disabled) together."

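    # Convert weights to fp32 during the copy; the full model is cast to the target dtype just before saving.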
param_to_weights = lambda param: param.float()

checkpoint = OrderedDict()
checkpoint['state_dict'] = OrderedDict()

embed_weight = ckpt[f'model.embed_tokens.weight']
if mcore_gpt:
embed_weights_base_name = f'model.embedding.word_embeddings.weight'
else:
embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight'
checkpoint['state_dict'][embed_weights_base_name] = param_to_weights(embed_weight)

if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num:
num_query_groups = head_num
else:
num_query_groups = nemo_config.num_query_groups
assert head_num % num_query_groups == 0, 'head_num must be divisible by num_query_groups'
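    # Grouped-query attention: Mistral-7B uses 32 query heads over 8 KV heads, i.e. 4 query heads per KV group.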
if mcore_gpt:
assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.'

for l in range(int(num_layers)):
print(f"converting layer {l}")
old_tensor_shape = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].size()
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

q = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].view(*new_q_tensor_shape)
k = ckpt[f'model.layers.{l}.self_attn.k_proj.weight'].view(*new_kv_tensor_shape)
v = ckpt[f'model.layers.{l}.self_attn.v_proj.weight'].view(*new_kv_tensor_shape)

# Note: we assume wq & wk have been appropriately transposed to work with
# NeMo/Megatron's rotary embedding. The reference checkpoint/implementation
        # will not work out of the box without transposing the wq/wk matrices.
heads_per_group = head_num // num_query_groups
qkv_weights_l = []
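        # Interleave per KV group: heads_per_group query heads, then one key head, then one value head,
        # matching the fused linear_qkv / query_key_value layout NeMo/Megatron expects.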
for i in range(num_query_groups):
qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
qkv_weights_l.append(k[i : i + 1, :, :])
qkv_weights_l.append(v[i : i + 1, :, :])
qkv_weights = torch.cat(qkv_weights_l)
assert qkv_weights.ndim == 3, qkv_weights.shape
assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
assert qkv_weights.shape[1] == head_size, qkv_weights.shape
assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape
qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])
if mcore_gpt:
qkv_weights_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'
else:
qkv_weights_base_name = f'model.language_model.encoder.layers.{l}.self_attention.query_key_value.weight'
checkpoint['state_dict'][qkv_weights_base_name] = param_to_weights(qkv_weights)

# attention dense
o_weight = ckpt[f'model.layers.{l}.self_attn.o_proj.weight']
if mcore_gpt:
o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight'
else:
o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight'
checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight)

# MLP
mlp_down_weight = ckpt[f'model.layers.{l}.mlp.gate_proj.weight']
mlp_gate_weight = ckpt[f'model.layers.{l}.mlp.up_proj.weight']
if mcore_gpt:
mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight'
else:
mlp_down_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight'
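        # Concatenate the gate and up projections into the single fused fc1 / dense_h_to_4h weight used by the gated (SwiGLU) MLP.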
mlp_down_weight = torch.cat((mlp_down_weight, mlp_gate_weight), axis=0)
checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight)

mlp_up_weight = ckpt[f'model.layers.{l}.mlp.down_proj.weight']
if mcore_gpt:
mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight'
else:
mlp_up_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_4h_to_h.weight'
checkpoint['state_dict'][mlp_up_base_name] = param_to_weights(mlp_up_weight)

# LayerNorm
input_ln_weight = ckpt[f'model.layers.{l}.input_layernorm.weight']
if mcore_gpt:
input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'
else:
input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight'
checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight)

post_attn_ln_weight = ckpt[f'model.layers.{l}.post_attention_layernorm.weight']
if mcore_gpt:
post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'
else:
post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight'
checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)

print(f"done layer {l}")

final_ln_weight = ckpt[f'model.norm.weight']
if mcore_gpt:
final_ln_base_name = f'model.decoder.final_layernorm.weight'
else:
final_ln_base_name = f'model.language_model.encoder.final_layernorm.weight'
checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight)

output_layer_weight = ckpt[f'lm_head.weight']
if mcore_gpt:
output_layer_base_name = f'model.output_layer.weight'
else:
output_layer_base_name = f'model.language_model.output_layer.weight'
checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)

checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config
del ckpt

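    # With megatron_amp_O2 the parameters live under an extra `module.` wrapper, so rename keys from 'model.' to 'model.module.'.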
if nemo_config.get('megatron_amp_O2', False):
keys = list(checkpoint['state_dict'].keys())
for key in keys:
checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key)

model = load_model(MegatronGPTModel, checkpoint, strict=False, trainer=trainer)

model._save_restore_connector = NLPSaveRestoreConnector()

# cast to target precision and disable cpu init
model = model.to(dtype=dtype)
model.cfg.use_cpu_initialization = False

model.save_to(args.out_file)
logging.info(f'NeMo model saved to: {args.out_file}')


if __name__ == '__main__':
args = get_args()
convert(args)