Akoumparouli/low mem mixtral ckpt converter (#8895)
* add --low-mem option to enable conversion of large checkpoints with low ram requirements

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* delete param_to_weights

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* various fixes; set hf dtype to auto

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* remove unused line

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
akoumpa and pre-commit-ci[bot] committed Apr 16, 2024
1 parent 12e7cf9 commit 468d5b6
Showing 1 changed file with 88 additions and 21 deletions.
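For reference, a low-RAM conversion with the new flags would be invoked roughly as follows; the checkpoint paths and the temporary directory are illustrative placeholders, not values taken from this commit:

python scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py \
    --input_name_or_path /path/to/Mixtral-8x7B-hf \
    --output_path /path/to/mixtral.nemo \
    --precision 32 \
    --low-ram \
    --tmp-dir /tmp/mixtral_ckpt_parts/

With --low-ram, each checkpoint part yielded by convert() is written to --tmp-dir as a .pth file and later re-loaded lazily via torch.load(..., mmap=True) and merged before save_to_nemo() runs; without the flag, the parts are merged directly in memory.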
scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py: 109 changes (88 additions, 21 deletions)
@@ -24,6 +24,7 @@
import os
from argparse import ArgumentParser
from collections import OrderedDict
from pathlib import Path

import megatron.core.parallel_state as parallel_state
import torch
@@ -43,6 +44,8 @@
)
from nemo.utils import logging

torch.set_grad_enabled(False)


def get_args():
parser = ArgumentParser()
@@ -51,6 +54,8 @@ def get_args():
)
parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.")
parser.add_argument("--precision", type=str, default="32", help="Model precision")
parser.add_argument('--low-ram', action='store_true')
parser.add_argument('--tmp-dir', default='/tmp/mixtral_ckpt_parts/')
args = parser.parse_args()
return args

@@ -108,6 +113,9 @@ def load_config(mixtral_config, tokenizer_path):
# RMSNorm's epsilon.
nemo_config.layernorm_epsilon = mixtral_config['rms_norm_eps']
nemo_config.normalization = 'rmsnorm'
nemo_config.micro_batch_size = 1
nemo_config.global_batch_size = 1
nemo_config.expert_model_parallel_size = 1

if 'num_key_value_heads' in mixtral_config:
nemo_config.num_query_groups = mixtral_config['num_key_value_heads']
@@ -132,24 +140,28 @@ def load_config(mixtral_config, tokenizer_path):
return nemo_config


def load_mixtral_ckpt(in_dir):
def load_hf_model_args(in_dir):
params_file = os.path.join(in_dir, 'config.json')
assert os.path.exists(params_file)
with open(params_file, 'r') as fp:
model_args = json.load(fp)
return model_args


model = AutoModelForCausalLM.from_pretrained(in_dir)
ckpt = model.state_dict()
def load_mixtral_ckpt(in_dir, load_model=True):
model_args = load_hf_model_args(in_dir)
ckpt = None
if load_model:
model = AutoModelForCausalLM.from_pretrained(in_dir, torch_dtype='auto')
ckpt = model.state_dict()

tokenizer = AutoTokenizer.from_pretrained(in_dir)
assert tokenizer.vocab_size == model_args['vocab_size']
return model_args, ckpt, tokenizer


def convert(args):
logging.info(f"loading checkpoint {args.input_name_or_path}")

model_args, ckpt, tokenizer = load_mixtral_ckpt(args.input_name_or_path)
def make_trainer(args, nemo_config):
model_args, ckpt, tokenizer = load_mixtral_ckpt(args.input_name_or_path, load_model=False)
nemo_config = load_config(model_args, tokenizer.vocab_file)

if args.precision in ["32", "16"]:
@@ -195,6 +207,14 @@ def convert(args):
print(f"nemo_config: {nemo_config}")

trainer = Trainer(plugins=plugins, accelerator='cpu', strategy=NLPDDPStrategy())
return trainer, dtype


def convert(args):
logging.info(f"loading checkpoint {args.input_name_or_path}")

model_args, ckpt, tokenizer = load_mixtral_ckpt(args.input_name_or_path)
nemo_config = load_config(model_args, tokenizer.vocab_file)

hidden_size = nemo_config.hidden_size
head_num = nemo_config.num_attention_heads
@@ -207,8 +227,6 @@ def convert(args):
'transformer_engine', False
), "mcore_gpt transformer_engine must be enabled (or disabled) together."

param_to_weights = lambda param: param.float()

checkpoint = OrderedDict()
checkpoint['state_dict'] = OrderedDict()

@@ -217,7 +235,7 @@ def convert(args):
embed_weights_base_name = f'model.embedding.word_embeddings.weight'
else:
embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight'
checkpoint['state_dict'][embed_weights_base_name] = param_to_weights(embed_weight)
checkpoint['state_dict'][embed_weights_base_name] = embed_weight

if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num:
num_query_groups = head_num
@@ -227,6 +245,10 @@ def convert(args):
if mcore_gpt:
assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.'

yield checkpoint
checkpoint = OrderedDict()
checkpoint['state_dict'] = OrderedDict()

for l in range(int(num_layers)):
print(f"converting layer {l}")
old_tensor_shape = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].size()
@@ -249,15 +271,15 @@ def convert(args):
qkv_weights_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'
else:
qkv_weights_base_name = f'model.language_model.encoder.layers.{l}.self_attention.query_key_value.weight'
checkpoint['state_dict'][qkv_weights_base_name] = param_to_weights(qkv_weights)
checkpoint['state_dict'][qkv_weights_base_name] = qkv_weights

# attention dense
o_weight = ckpt[f'model.layers.{l}.self_attn.o_proj.weight']
if mcore_gpt:
o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight'
else:
o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight'
checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight)
checkpoint['state_dict'][o_weight_base_name] = o_weight

# # MLP
# Handle gate
@@ -266,7 +288,7 @@ def convert(args):
moe_gate_name = f'model.decoder.layers.{l}.mlp.router.weight'
else:
raise Exception("not implemented")
checkpoint['state_dict'][moe_gate_name] = param_to_weights(moe_gate)
checkpoint['state_dict'][moe_gate_name] = moe_gate
# Handle experts
for i in range(nemo_config.num_moe_experts):
gate_proj = ckpt[f'model.layers.{l}.block_sparse_moe.experts.{i}.w1.weight']
@@ -276,14 +298,14 @@ def convert(args):
else:
raise Exception("not implemented")
mlp_down_weight = torch.cat((gate_proj, up_proj), axis=0)
checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight)
checkpoint['state_dict'][mlp_down_base_name] = mlp_down_weight

mlp_up_weight = ckpt[f'model.layers.{l}.block_sparse_moe.experts.{i}.w2.weight']
if mcore_gpt:
mlp_up_base_name = f'model.decoder.layers.{l}.mlp.experts.local_experts.{i}.linear_fc2.weight'
else:
raise Exception("not implemented")
checkpoint['state_dict'][mlp_up_base_name] = param_to_weights(mlp_up_weight)
checkpoint['state_dict'][mlp_up_base_name] = mlp_up_weight

# LayerNorm
input_ln_weight = ckpt[f'model.layers.{l}.input_layernorm.weight']
@@ -292,7 +314,7 @@ def convert(args):
input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'
else:
input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight'
checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight)
checkpoint['state_dict'][input_ln_base_name] = input_ln_weight

post_attn_ln_weight = ckpt[f'model.layers.{l}.post_attention_layernorm.weight']
if mcore_gpt:
@@ -301,28 +323,57 @@ def convert(args):
post_attn_ln_base_name = f'model.decoder.layers.{l}.pre_mlp_layernorm.weight'
else:
post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight'
checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
checkpoint['state_dict'][post_attn_ln_base_name] = post_attn_ln_weight

print(f"done layer {l}")

yield checkpoint
checkpoint = OrderedDict()
checkpoint['state_dict'] = OrderedDict()

final_ln_weight = ckpt[f'model.norm.weight']
if mcore_gpt:
final_ln_base_name = f'model.decoder.final_layernorm.weight'
else:
final_ln_base_name = f'model.language_model.encoder.final_layernorm.weight'
checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight)
checkpoint['state_dict'][final_ln_base_name] = final_ln_weight

output_layer_weight = ckpt[f'lm_head.weight']
if mcore_gpt:
output_layer_base_name = f'model.output_layer.weight'
else:
output_layer_base_name = f'model.language_model.output_layer.weight'
checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)
checkpoint['state_dict'][output_layer_base_name] = output_layer_weight

checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config

yield checkpoint
del ckpt


def merge(a: dict, b: dict, path=[]):
is_dict = lambda x: isinstance(x, OrderedDict) or isinstance(x, dict)
for key in b:
if key in a:
if is_dict(a[key]) and is_dict(b[key]):
merge(a[key], b[key], path + [str(key)])
elif a[key] != b[key]:
raise Exception('Value conflict: ' + '.'.join(path + [str(key)]))
else:
a[key] = b[key]
return a


def save_to_nemo(args, checkpoint):

logging.info(f"loading checkpoint {args.input_name_or_path}")
model_args, ckpt, tokenizer = load_mixtral_ckpt(args.input_name_or_path, load_model=False)
nemo_config = load_config(model_args, tokenizer.vocab_file)
trainer, dtype = make_trainer(args, nemo_config)

checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config
checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY].use_cpu_initialization = True
checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY].perform_initialization = False

if nemo_config.get('megatron_amp_O2', False):
keys = list(checkpoint['state_dict'].keys())
for key in keys:
@@ -342,5 +393,21 @@ def convert(args):

if __name__ == '__main__':
args = get_args()
if args.low_ram:
os.makedirs(args.tmp_dir, exist_ok=True)

parallel_state.set_expert_model_parallel_world_size(1)
convert(args)
checkpoint = OrderedDict()
for i, ckpt_part in enumerate(convert(args)):
if args.low_ram:
torch.save(ckpt_part, f'{args.tmp_dir}/nemo_ckpt_part_{i}.pth')
else:
checkpoint = merge(checkpoint, ckpt_part)

if args.low_ram:
print("Loading partial checkpoints")
for path in map(str, Path(args.tmp_dir).rglob("*.pth")):
print(f"Loading checkpoint: {path}")
checkpoint = merge(checkpoint, torch.load(path, mmap=True))

save_to_nemo(args, checkpoint)
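As a minimal sketch of how the partial checkpoints combine, the snippet below exercises the recursive merge() helper on two toy parts shaped like the ones convert() yields; the key names and values are stand-ins for illustration, and the merge() body is reproduced from the diff above so the example is self-contained:

from collections import OrderedDict


def merge(a: dict, b: dict, path=[]):
    # Recursively union nested dicts; non-dict values that differ raise a conflict.
    is_dict = lambda x: isinstance(x, OrderedDict) or isinstance(x, dict)
    for key in b:
        if key in a:
            if is_dict(a[key]) and is_dict(b[key]):
                merge(a[key], b[key], path + [str(key)])
            elif a[key] != b[key]:
                raise Exception('Value conflict: ' + '.'.join(path + [str(key)]))
        else:
            a[key] = b[key]
    return a


# Two partial checkpoints, as convert() would yield them (toy values instead of tensors).
part_0 = OrderedDict(state_dict=OrderedDict({'model.embedding.word_embeddings.weight': 0}))
part_1 = OrderedDict(state_dict=OrderedDict({'model.decoder.final_layernorm.weight': 1}))

checkpoint = OrderedDict()
for part in (part_0, part_1):
    checkpoint = merge(checkpoint, part)

# The merged dict now holds the union of both parts' state_dict entries.
assert set(checkpoint['state_dict']) == {
    'model.embedding.word_embeddings.weight',
    'model.decoder.final_layernorm.weight',
}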
