diff --git a/Jenkinsfile b/Jenkinsfile
index 5d81a57c04c9..636d1b519f25 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -423,6 +423,14 @@ pipeline {
         sh 'rm -f /home/TestData/nlp/megatron_gpt/falcon-ci-hf/falcon_ci.nemo'
       }
     }
+    stage('Baichuan2') {
+      steps {
+        sh 'python scripts/nlp_language_modeling/convert_hf_baichuan2_to_nemo.py \
+        --in-file=/home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base \
+        --out-file=/home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base/ci.nemo'
+        sh 'rm -f /home/TestData/nlp/megatron_gpt/Baichuan2-7B-Base/ci.nemo'
+      }
+    }
   }
 }
diff --git a/examples/nlp/language_modeling/conf/megatron_baichuan2_config.yaml b/examples/nlp/language_modeling/conf/megatron_baichuan2_config.yaml
new file mode 100644
index 000000000000..6f9077334fb3
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/megatron_baichuan2_config.yaml
@@ -0,0 +1,225 @@
+name: megatron_baichuan2
+restore_from_path: null # used when starting from a .nemo file
+
+trainer:
+  devices: 1
+  num_nodes: 1
+  accelerator: gpu
+  precision: 32
+  logger: False # logger provided by exp_manager
+  enable_checkpointing: False
+  use_distributed_sampler: False
+  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
+  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
+  log_every_n_steps: 10
+  val_check_interval: 100
+  limit_val_batches: 50
+  limit_test_batches: 500
+  accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models
+  gradient_clip_val: 1.0
+  benchmark: False
+  enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually
+
+exp_manager:
+  explicit_log_dir: null
+  exp_dir: null
+  name: megatron_baichuan2
+  create_wandb_logger: False
+  wandb_logger_kwargs:
+    project: null
+    name: null
+  resume_if_exists: True
+  resume_ignore_no_checkpoint: True
+  create_checkpoint_callback: True
+  checkpoint_callback_params:
+    monitor: val_loss
+    save_top_k: 10
+    mode: min
+    always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
+    save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits
+    filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}'
+    model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
+
+model:
+  mcore_gpt: True
+  # specify micro_batch_size, global_batch_size, and model parallelism
+  # gradient accumulation will be done automatically based on data_parallel_size
+  micro_batch_size: 4 # limited by GPU memory
+  global_batch_size: 8 # will use more micro batches to reach global batch size
+  tensor_model_parallel_size: 1 # intra-layer model parallelism
+  pipeline_model_parallel_size: 1 # inter-layer model parallelism
+  virtual_pipeline_model_parallel_size: null # interleaved pipeline
+
+  # model architecture
+  encoder_seq_length: 4096
+  max_position_embeddings: ${.encoder_seq_length}
+  num_layers: 32 # 7b: 32 | 13b: 40
+  hidden_size: 4096 # 7b: 4096 | 13b: 5120
+  ffn_hidden_size: 11008 # Transformer FFN hidden size. Usually 4 * hidden_size. | 7b: 11008 | 13b: 13696
+  num_attention_heads: 32 # 7b: 32 | 13b: 40
+  init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.
+  use_scaled_init_method: True # use scaled residuals initialization
+  hidden_dropout: 0.0 # Dropout probability for hidden state transformer.
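+  # The per-field "7b | 13b" comments above (together with num_query_groups and position_embedding_type
+  # further down) summarize the two released Baichuan2 sizes: the defaults in this file match Baichuan2-7B,
+  # while Baichuan2-13B uses num_layers=40, hidden_size=5120, ffn_hidden_size=13696,
+  # num_attention_heads=40, num_query_groups=40 and position_embedding_type='alibi'.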
+  attention_dropout: 0.0 # Dropout probability for attention
+  ffn_dropout: 0.0 # Dropout probability in the feed-forward layer.
+  kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
+  apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
+  normalization: 'rmsnorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
+  layernorm_epsilon: 1e-6
+  do_layer_norm_weight_decay: False # True means weight decay on all params
+  make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
+  pre_process: True # add embedding
+  post_process: True # add pooler
+  persist_layer_norm: True # Use of persistent fused layer norm kernel.
+  bias: False # Whether to use bias terms in all weight matrices.
+  activation: 'swiglu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
+  headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
+  transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
+  openai_gelu: False # Use OpenAI's GELU instead of the default GeLU
+  normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to set this to True.
+  position_embedding_type: 'rope' # Position embedding type. Options ['learned_absolute', 'rope']. | 7b: 'rope' | 13b: 'alibi'
+  rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this.
+  attention_type: 'multihead' # Attention type. Options ['multihead']
+  share_embeddings_and_output_weights: False # Share embedding and output layer weights.
+  overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
+  batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
+  num_query_groups: 32 # Number of query groups for group query attention. If None, normal attention is used. | 7b: 32 | 13b: 40
+
+  tokenizer:
+    library: 'sentencepiece'
+    type: null
+    model: ??? # /path/to/tokenizer.model
+    vocab_file: null
+    merge_file: null
+    delimiter: null # only used for tabular tokenizer
+    sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers.
+    trust_remote_code: True
+
+  # Mixed precision
+  native_amp_init_scale: 4294967296 # 2 ** 32
+  native_amp_growth_interval: 1000
+  hysteresis: 2 # Gradient scale hysteresis
+  fp32_residual_connection: False # Move residual connections to fp32
+  fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16
+
+  # Megatron O2-style half-precision
+  megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters # False!!!
+  grad_allreduce_chunk_size_mb: 125
+
+  # Fusion
+  grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism.
+  gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2.
+  bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
+  bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition.
+  masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with its mask.
+  get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.
+
+
+  # Miscellaneous
+  seed: 1234
+  resume_from_checkpoint: null # manually set the checkpoint file to load from
+  use_cpu_initialization: False # Init weights on the CPU (slow for large models)
+  onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
+  apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
+  gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
+  sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages
+
+  ## Activation Checkpointing
+  # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
+  # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
+  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
+  # 'full' will checkpoint the entire transformer layer.
+  activations_checkpoint_granularity: null # 'selective' or 'full'
+  activations_checkpoint_method: null # 'uniform', 'block'
+  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
+  # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
+  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
+  activations_checkpoint_num_layers: null
+  # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
+  # when using 'block' this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
+  num_micro_batches_with_partial_activation_checkpoints: null
+  # This feature is valid only when used with pipeline-model-parallelism.
+  # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
+  # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is
+  # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
+  # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'.
+  # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage.
+  activations_checkpoint_layers_per_pipeline: null
+  # This feature is valid only when used with pipeline-model-parallelism.
+  # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
+  # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 checkpoint 3 layers less than
+  # stage 0 and stage 2 checkpoint 6 layers less than stage 0, and so on. This is possible because later pipeline stages
+  # use less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
+  # this feature removes most of the activation checkpoints at the last pipeline stage, which is the critical execution path.
+
+  ## Sequence Parallelism
+  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
+  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
+  sequence_parallel: False
+
+  ## Transformer Engine
+  transformer_engine: True
+  fp8: False # enables fp8 in TransformerLayer forward
+  fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3
+  fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID
+  fp8_margin: 0 # scaling margin
+  fp8_interval: 1 # scaling update interval
+  fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor
+  fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history
+  reduce_amax: True # Perform reduction to sync amax tensors across GPUs after every iteration
+  use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False.
+
+  data:
+    # Path to data must be specified by the user.
+    # Supports List, String and Dictionary
+    # List : can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]",
+    # Or see example below:
+    # data_prefix:
+    #   - .5
+    #   - /raid/data/pile/my-gpt3_00_text_document
+    #   - .5
+    #   - /raid/data/pile/my-gpt3_01_text_document
+    # Dictionary: can override from CLI "model.data.data_prefix"={"train":[1.0, /path/to/data], "validation":/path/to/data, "test":/path/to/test}
+    # Or see example below:
+    # "model.data.data_prefix: {train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}"
+    # data_prefix: ???
+    index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
+    data_impl: mmap
+    splits_string: 900,50,50
+    seq_length: ${model.encoder_seq_length}
+    skip_warmup: True
+    num_workers: 2
+    dataloader_type: single # cyclic
+    reset_position_ids: False # Reset position ids after end-of-document token
+    reset_attention_mask: False # Reset attention mask after end-of-document token
+    eod_mask_loss: False # Mask loss for the end of document tokens
+    validation_drop_last: True # Set to False if the last partial validation samples are to be consumed
+    no_seqlen_plus_one_input_tokens: False # Set to True to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token
+    pad_samples_to_global_batch_size: False # Set to True if you want to pad the last partial batch with -1's to equal global batch size
+    shuffle_documents: True # Set to False to disable documents shuffling. Sample index will still be shuffled
+
+    # Nsys profiling options
+    nsys_profile:
+      enabled: False
+      start_step: 10 # Global batch to start profiling
+      end_step: 10 # Global batch to end profiling
+      ranks: [0] # Global rank IDs to profile
+      gen_shape: False # Generate model and kernel details including input shapes
+
+  optim:
+    name: distributed_fused_adam
+    bucket_cap_mb: 100
+    overlap_grad_sync: True
+    overlap_param_sync: True
+    contiguous_grad_buffer: True
+    grad_sync_dtype: bf16
+    lr: 2e-4
+    weight_decay: 0.01
+    betas:
+      - 0.9
+      - 0.98
+    sched:
+      name: CosineAnnealing
+      warmup_steps: 500
+      constant_steps: 50000
+      min_lr: 2e-5
diff --git a/examples/nlp/language_modeling/conf/megatron_baichuan2_inference.yaml b/examples/nlp/language_modeling/conf/megatron_baichuan2_inference.yaml
new file mode 100644
index 000000000000..f2d9d6ab56ba
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/megatron_baichuan2_inference.yaml
@@ -0,0 +1,39 @@
+inference:
+  greedy: False # Whether or not to use sampling; use greedy decoding otherwise
+  top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
+  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+  temperature: 1.0 # sampling temperature
+  add_BOS: False # add the bos token at the beginning of the prompt
+  tokens_to_generate: 30 # The maximum number of tokens to generate.
+  all_probs: False # whether to return the log prob for all the tokens in the vocab
+  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
+  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
+  compute_logprob: False # a flag used to compute the logprob of all the input text, a very special case of running inference, default False
+  end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated
+
+trainer:
+  devices: 1
+  num_nodes: 1
+  accelerator: gpu
+  logger: False # logger provided by exp_manager
+  precision: bf16 # 16, 32, or bf16
+  use_distributed_sampler: False
+
+tensor_model_parallel_size: -1
+pipeline_model_parallel_size: -1
+pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others)
+megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory
+gpt_model_file: null # GPT nemo file path
+checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training
+checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading
+hparams_file: null # model configuration file, only used for PTL checkpoint loading
+prompts: # prompts for GPT inference
+  - "Q: How are you?"
+  - "Q: How big is the universe?"
+server: False # whether to launch the API server
+port: 5555 # the port number for the inference server
+web_server: False # whether to launch the web inference server
+share: False # whether to create a public URL
+username: test # user name for web client
+password: test2 # password for web client
+web_port: 9889 # the port number of the web server
diff --git a/scripts/nlp_language_modeling/convert_hf_baichuan2_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_baichuan2_to_nemo.py
new file mode 100644
index 000000000000..217c24fd85c1
--- /dev/null
+++ b/scripts/nlp_language_modeling/convert_hf_baichuan2_to_nemo.py
@@ -0,0 +1,316 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""
+Conversion script to convert Huggingface Baichuan2 checkpoints into a nemo checkpoint.
+  Example to run this conversion script:
+    python convert_hf_baichuan2_to_nemo.py \
+     --in-file <path_to_hf_checkpoints_folder> \
+     --out-file <path_to_output_nemo_file>
+"""
+
+import os
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+from omegaconf import OmegaConf
+from pytorch_lightning.trainer.trainer import Trainer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.collections.nlp.parts.nlp_overrides import (
+    GradScaler,
+    MegatronHalfPrecisionPlugin,
+    NLPDDPStrategy,
+    NLPSaveRestoreConnector,
+    PipelineMixedPrecisionPlugin,
+)
+from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper, torch_dtype_from_precision
+from nemo.utils import logging
+
+
+def get_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--in-file", type=str, default=None, required=True, help="Path to Huggingface Baichuan2 checkpoints",
+    )
+    parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.")
+    parser.add_argument("--precision", type=str, default="32", help="Model precision")
+    args = parser.parse_args()
+    return args
+
+
+def load_model(cls, checkpoint, strict, **kwargs):
+    try:
+        if 'cfg' in kwargs:
+            model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
+        else:
+            model = cls(cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs)
+            for name, module in model.named_parameters():
+                if name in checkpoint['state_dict']:
+                    module.data = checkpoint['state_dict'][name]
+                    checkpoint['state_dict'].pop(name)
+                else:
+                    print(f"Unexpected key: {name} not in checkpoint but in model.")
+
+            for name, buffer in model.named_buffers():
+                if name in checkpoint['state_dict']:
+                    buffer.data = checkpoint['state_dict'][name]
+                    checkpoint['state_dict'].pop(name)
+
+            if len(checkpoint['state_dict'].keys()) != 0:
+                raise RuntimeError(
+                    f"Additional keys: {checkpoint['state_dict'].keys()} in checkpoint but not in model."
+ ) + + # register the artifacts + cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + if cfg.tokenizer.model is not None: + model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model) + if cfg.tokenizer.vocab_file is not None: + model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file) + if cfg.tokenizer.merge_file is not None: + model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file) + finally: + cls._set_model_restore_state(is_being_restored=False) + return model + + +def load_config(args, baichuan2_config): + nemo_config = OmegaConf.load( + os.path.join( + os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_baichuan2_config.yaml' + ) + ).model + if 'max_position_embeddings' in baichuan2_config: + nemo_config.encoder_seq_length = baichuan2_config['max_position_embeddings'] + nemo_config.num_layers = baichuan2_config['num_hidden_layers'] + nemo_config.hidden_size = baichuan2_config['hidden_size'] + nemo_config.ffn_hidden_size = baichuan2_config['intermediate_size'] + nemo_config.num_attention_heads = baichuan2_config['num_attention_heads'] + nemo_config.num_query_groups = baichuan2_config['num_attention_heads'] + nemo_config.init_method_std = baichuan2_config['initializer_range'] + nemo_config.layernorm_epsilon = baichuan2_config['rms_norm_eps'] + nemo_config.use_cpu_initialization = True + nemo_config.activation = 'fast-swiglu' + nemo_config.tokenizer.model = baichuan2_config['tokenizer_model'] + if nemo_config.num_layers == 32: + nemo_config.position_embedding_type = 'rope' + elif nemo_config.num_layers == 40: + nemo_config.position_embedding_type = 'alibi' + + base = 128 + while baichuan2_config['vocab_size'] % base != 0: + base //= 2 + nemo_config.make_vocab_size_divisible_by = base + + return nemo_config + + +def convert(args): + logging.info(f"loading checkpoint {args.in_file}") + model = AutoModelForCausalLM.from_pretrained(args.in_file, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(args.in_file, trust_remote_code=True) + hf_config = vars(model.config) + hf_config['tokenizer_model'] = str(tokenizer.vocab_file) + print(f"hf_config: {hf_config}") + print("named parameters:") + for name, param in model.named_parameters(): + print(f"- {name}") + + nemo_config = load_config(args, hf_config) + + if args.precision in ["32", "16"]: + precision = int(float(args.precision)) + elif args.precision in ["bf16", "bf16-mixed"]: + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + precision = args.precision + else: + logging.warning("BF16 is not supported on this device. 
Using FP16 instead.") + precision = args.precision[2:] # prune bf in string + else: + precision = args.precision + + plugins = [] + if precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']: + scaler = None + if precision in [16, '16', '16-mixed']: + scaler = GradScaler( + init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + growth_interval=nemo_config.get('native_amp_growth_interval', 1000), + hysteresis=nemo_config.get('hysteresis', 2), + ) + # MixedPrecisionPlugin in PTL >= 2.0 requires precision to be 16-mixed or bf16-mixed + plugin_precision = '16-mixed' + else: + plugin_precision = 'bf16-mixed' + + if nemo_config.get('megatron_amp_O2', False): + plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) + + nemo_config.precision = precision + print(f"nemo_config: {nemo_config}") + + trainer = Trainer(plugins=plugins, accelerator='cpu', precision=precision, strategy=NLPDDPStrategy()) + + hidden_size = hf_config["hidden_size"] + head_num = hf_config["num_attention_heads"] + head_size = hidden_size // head_num + num_layers = hf_config["num_hidden_layers"] + + mcore_gpt = nemo_config.mcore_gpt + + assert mcore_gpt == nemo_config.get( + 'transformer_engine', False + ), "mcore_gpt transformer_engine must be enabled (or disabled) together." + + param_to_weights = lambda param: param.float() + + checkpoint = OrderedDict() + checkpoint['state_dict'] = OrderedDict() + + embed_weight = model.state_dict()[f'model.embed_tokens.weight'] + if mcore_gpt: + embed_weights_base_name = f'model.embedding.word_embeddings.weight' + else: + embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight' + checkpoint['state_dict'][embed_weights_base_name] = param_to_weights(embed_weight) + + # in hf, this is defined as register_buffer(..., persistent=False) so it won't be in the state dict + if f'model.layers.0.self_attn.rotary_emb.inv_freq' in model.state_dict(): + rotary_embed_weight = model.state_dict()[f'model.layers.0.self_attn.rotary_emb.inv_freq'] + if mcore_gpt: + rotary_embed_weight_base_name = f'model.rotary_pos_emb.inv_freq' + else: + rotary_embed_weight_base_name = f'model.language_model.rotary_pos_emb.inv_freq' + checkpoint['state_dict'][rotary_embed_weight_base_name] = param_to_weights(rotary_embed_weight) + + if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num: + num_query_groups = head_num + else: + num_query_groups = nemo_config.num_query_groups + assert head_num % num_query_groups == 0, 'head_num must be divisible by num_query_groups' + if mcore_gpt: + assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.' 
+ + for l in range(int(num_layers)): + print(f"converting layer {l}") + qkv_weights = model.state_dict()[f'model.layers.{l}.self_attn.W_pack.weight'] + qkv_weights = qkv_weights.unflatten(0, (3, hidden_size)) + old_tensor_shape = qkv_weights[0].squeeze().size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + q = qkv_weights[0].squeeze().view(*new_q_tensor_shape) + k = qkv_weights[1].squeeze().view(*new_kv_tensor_shape) + v = qkv_weights[2].squeeze().view(*new_kv_tensor_shape) + qkv_weights = torch.empty((0, head_size) + old_tensor_shape[1:]) + heads_per_group = head_num // num_query_groups + for i in range(num_query_groups): + qkv_weights = torch.cat((qkv_weights, q[i * heads_per_group : (i + 1) * heads_per_group, :, :])) + qkv_weights = torch.cat((qkv_weights, k[i : i + 1, :, :])) + qkv_weights = torch.cat((qkv_weights, v[i : i + 1, :, :])) + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + if mcore_gpt: + qkv_weights_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.weight' + else: + qkv_weights_base_name = f'model.language_model.encoder.layers.{l}.self_attention.query_key_value.weight' + checkpoint['state_dict'][qkv_weights_base_name] = param_to_weights(qkv_weights) + + # attention dense + o_weight = model.state_dict()[f'model.layers.{l}.self_attn.o_proj.weight'] + if mcore_gpt: + o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight' + else: + o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight' + checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight) + + # MLP + mlp_down_weight = model.state_dict()[f'model.layers.{l}.mlp.gate_proj.weight'] + mlp_gate_weight = model.state_dict()[f'model.layers.{l}.mlp.up_proj.weight'] + if mcore_gpt: + mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight' + else: + mlp_down_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight' + mlp_down_weight = torch.cat((mlp_down_weight, mlp_gate_weight), axis=0) + checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight) + + mlp_up_weight = model.state_dict()[f'model.layers.{l}.mlp.down_proj.weight'] + if mcore_gpt: + mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight' + else: + mlp_up_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_4h_to_h.weight' + checkpoint['state_dict'][mlp_up_base_name] = param_to_weights(mlp_up_weight) + + # LayerNorm + input_ln_weight = model.state_dict()[f'model.layers.{l}.input_layernorm.weight'] + if mcore_gpt: + input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' + else: + input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight' + checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight) + + post_attn_ln_weight = model.state_dict()[f'model.layers.{l}.post_attention_layernorm.weight'] + if mcore_gpt: + post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight' + else: + post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight' + checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight) + + print(f"done layer {l}") + + final_ln_weight = model.state_dict()[f'model.norm.weight'] + if mcore_gpt: + final_ln_base_name = 
f'model.decoder.final_layernorm.weight' + else: + final_ln_base_name = f'model.language_model.encoder.final_layernorm.weight' + checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight) + + output_layer_weight = model.state_dict()[f'lm_head.weight'] + if mcore_gpt: + output_layer_base_name = f'model.output_layer.weight' + else: + output_layer_base_name = f'model.language_model.output_layer.weight' + checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight) + + checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config + + del model + + if nemo_config.get('megatron_amp_O2', False): + keys = list(checkpoint['state_dict'].keys()) + for key in keys: + checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key) + + model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, checkpoint['state_dict']) + + model._save_restore_connector = NLPSaveRestoreConnector() + + # cast to target precision and disable cpu init + dtype = torch_dtype_from_precision(precision) + model = model.to(dtype=dtype) + model.cfg.use_cpu_initialization = False + + model.save_to(args.out_file) + logging.info(f'NeMo model saved to: {args.out_file}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/nlp_language_modeling/convert_nemo_baichuan2_to_hf.py b/scripts/nlp_language_modeling/convert_nemo_baichuan2_to_hf.py new file mode 100644 index 000000000000..41379574882e --- /dev/null +++ b/scripts/nlp_language_modeling/convert_nemo_baichuan2_to_hf.py @@ -0,0 +1,221 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from argparse import ArgumentParser +from collections import OrderedDict + +import torch +from pytorch_lightning import Trainer +from transformers import AutoModelForCausalLM + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.utils import logging + +""" +Script to convert a baichuan2 checkpoint in nemo (mcore path) into a HuggingFace checkpoint. +This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder. + +1) Generate only HF weights from a nemo file: + + python convert_nemo_baichuan2_to_hf.py \ + --in-file /path/to/file.nemo or /path/to/extracted_folder \ + --out-file /path/to/pytorch_model.bin + +2) Generate the full HF model folder + + python convert_nemo_baichuan2_to_hf.py \ + --in-file /path/to/file.nemo or /path/to/extracted_folder \ + --out-file /path/to/pytorch_model.bin \ + --hf-in-path /path/to/input_hf_folder \ + --hf-out-path /path/to/output_hf_folder + + Use the --cpu-only flag if the model cannot fit in the GPU. + However this option makes the conversion script significantly slower. 
+""" + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--in-file", type=str, default=None, required=True, help="Path to .nemo file", + ) + parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to HF .bin file") + parser.add_argument( + "--hf-in-path", + type=str, + default=None, + help="A HF model path, " + "e.g. a folder containing https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/tree/main", + ) + parser.add_argument( + "--hf-out-path", + type=str, + default=None, + help="Output HF model path, " "with the same format as above but user's own weights", + ) + parser.add_argument( + "--precision", + type=str, + default=None, + help="Precision of output weights." + "Defaults to precision of the input nemo weights (model.cfg.trainer.precision)", + ) + parser.add_argument( + "--cpu-only", + action="store_true", + help="Load model in cpu only. Useful if the model cannot fit in GPU memory, " + "but this option makes the conversion script significantly slower.", + ) + args = parser.parse_args() + return args + + +def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None: + """ + Convert NeMo weights to HF weights + """ + + dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy()) + if cpu_only: + map_location = torch.device('cpu') + model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True) + model_config.use_cpu_initialization = True + model_config.tensor_model_parallel_size = 1 + else: + map_location, model_config = None, None + + if cpu_only: + logging.info("******** Loading model on CPU. This will take a significant amount of time.") + model = MegatronGPTModel.restore_from( + input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location + ) + if precision is None: + precision = model.cfg.precision + if precision in [32, "32"]: + dtype = torch.float32 + elif precision in [16, "16", "16-mixed"]: + dtype = torch.float16 + elif precision in ["bf16", "bf16-mixed"]: + dtype = torch.bfloat16 + else: + logging.warning(f"Precision string {precision} is not recognized, falling back to fp32") + dtype = torch.float32 # fallback + + param_to_weights = lambda param: param.to(dtype) + checkpoint = OrderedDict() + + hidden_size = model.cfg.hidden_size + head_num = model.cfg.num_attention_heads + num_layers = model.cfg.num_layers + ffn_hidden_size = model.cfg.ffn_hidden_size + num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B + + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + # Embedding + embed_weight = model.state_dict()[f'model.embedding.word_embeddings.weight'] + embed_weights_base_name = f'model.embed_tokens.weight' + checkpoint[embed_weights_base_name] = param_to_weights(embed_weight) + + for l in range(int(num_layers)): + print(f"converting layer {l}") + + qkv_weights = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + qkv_weights_base_name = 
f'model.layers.{l}.self_attn.W_pack.weight' + checkpoint[qkv_weights_base_name] = param_to_weights( + torch.cat( + [ + qkv_weights[q_slice].reshape(-1, hidden_size), + qkv_weights[k_slice].reshape(-1, hidden_size), + qkv_weights[v_slice].reshape(-1, hidden_size), + ] + ) + ) + + # attention dense + o_weight = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_proj.weight'] + o_weight_base_name = f'model.layers.{l}.self_attn.o_proj.weight' + checkpoint[o_weight_base_name] = param_to_weights(o_weight) + + # mlp + mlp_weights = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.weight'] + mlp_down_proj_weight = mlp_weights[:ffn_hidden_size, :] + mlp_gate_proj_weight = mlp_weights[ffn_hidden_size:, :] + + mlp_down_proj_base_name = f'model.layers.{l}.mlp.gate_proj.weight' + mlp_gate_proj_base_name = f'model.layers.{l}.mlp.up_proj.weight' + + checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight) + checkpoint[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight) + + mlp_up_proj_weight = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc2.weight'] + mlp_up_proj_base_name = f'model.layers.{l}.mlp.down_proj.weight' + checkpoint[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight) + + # layernorm + input_ln_weight = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'] + input_ln_base_name = f'model.layers.{l}.input_layernorm.weight' + checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight) + + post_attn_ln_weight = model.state_dict()[f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'] + post_attn_ln_base_name = f'model.layers.{l}.post_attention_layernorm.weight' + checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight) + + print(f"done layer {l}") + + final_ln_weight = model.state_dict()[f'model.decoder.final_layernorm.weight'] + final_ln_base_name = f'model.norm.weight' + checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight) + + output_layer_weight = model.state_dict()[f'model.output_layer.weight'] + output_layer_base_name = f'lm_head.weight' + checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight) + + os.makedirs(os.path.dirname(output_hf_file), exist_ok=True) + torch.save(checkpoint, output_hf_file) + logging.info(f"Weights saved to {output_hf_file}") + + +def replace_hf_weights(weights_file, input_hf_path, output_hf_path): + model = AutoModelForCausalLM.from_pretrained(input_hf_path, local_files_only=True, trust_remote_code=True) + nemo_exported = torch.load(weights_file) + + model.load_state_dict(nemo_exported) + model.save_pretrained(output_hf_path) + logging.info(f"Full HF model saved to {output_hf_path}") + + +if __name__ == '__main__': + args = get_args() + convert(args.in_file, args.out_file, precision=args.precision, cpu_only=args.cpu_only) + if args.hf_in_path and args.hf_out_path: + replace_hf_weights(args.out_file, args.hf_in_path, args.hf_out_path) + else: + logging.info("`hf-in-path` and/or `hf-out-path` not provided, not generating full HF model.") + logging.info(f".bin file is saved to {args.out_file}")
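Not part of the patch above, but a quick way to exercise both scripts end to end: the sketch below compares a few tensors after an HF -> NeMo -> HF round trip, loading the original HF checkpoint and the weights re-exported by convert_nemo_baichuan2_to_hf.py. All paths are placeholders, and it assumes the full fp32 state dicts fit in host memory.

# round_trip_check.py -- illustrative sketch only; adjust paths to your environment.
import torch
from transformers import AutoModelForCausalLM

hf_dir = "/path/to/Baichuan2-7B-Base"        # original HF checkpoint (input to convert_hf_baichuan2_to_nemo.py)
exported_bin = "/path/to/pytorch_model.bin"  # --out-file of convert_nemo_baichuan2_to_hf.py

original = AutoModelForCausalLM.from_pretrained(hf_dir, trust_remote_code=True).state_dict()
exported = torch.load(exported_bin, map_location="cpu")

shared = set(original) & set(exported)
print(f"{len(shared)} shared tensors, {len(set(original) - set(exported))} only in the original checkpoint")

# Spot-check a few representative tensors; a bf16/fp16 round trip will show small differences.
for key in ("model.embed_tokens.weight", "model.layers.0.self_attn.W_pack.weight", "lm_head.weight"):
    diff = (original[key].float() - exported[key].float()).abs().max().item()
    print(f"{key}: max abs diff = {diff:.3e}")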