diff --git a/Jenkinsfile b/Jenkinsfile index 045c53c4f7d1..f2e0704a0ea5 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -143,6 +143,15 @@ pipeline { sh 'rm -f /home/TestData/nlp/megatron_gpt/starcoder-ci-hf/megatron_starcoder_tp1_pp1.nemo' } } + stage('Falcon') { + steps { + sh 'python scripts/nlp_language_modeling/convert_hf_falcon_to_nemo.py \ + --config examples/nlp/language_modeling/conf/megatron_falcon_config.yaml \ + --input /home/TestData/nlp/megatron_gpt/falcon-ci-hf \ + --output /home/TestData/nlp/megatron_gpt/falcon-ci-hf/falcon_ci.nemo' + sh 'rm -f /home/TestData/nlp/megatron_gpt/falcon-ci-hf/falcon_ci.nemo' + } + } } } diff --git a/examples/nlp/language_modeling/conf/megatron_falcon_config.yaml b/examples/nlp/language_modeling/conf/megatron_falcon_config.yaml new file mode 100644 index 000000000000..4b8009256a9e --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_falcon_config.yaml @@ -0,0 +1,219 @@ +name: megatron_falcon_gpt +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_falcon_gpt + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits + filename: 'megatron_falcon--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + +model: + mcore_gpt: True + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 1 # limited by GPU memory + global_batch_size: 1 # will use more micro batches to reach global batch size + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + # model architecture + encoder_seq_length: 2048 + max_position_embeddings: ${.encoder_seq_length} + num_layers: 32 # 7b: 32 | 40b: 60 | 180b: 80 + hidden_size: 4544 # 7b: 4544 | 40b: 8192 | 180b: 14848 + ffn_hidden_size: 18176 # Transformer FFN hidden size. Usually 4 * hidden_size. 
| 7b: 18176 | 40b: 32768 | 180b: 59392
+  num_attention_heads: 71 # 7b: 71 | 40b: 128 | 180b: 232
+  init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.
+  use_scaled_init_method: True # use scaled residuals initialization
+  hidden_dropout: 0.0 # Dropout probability for hidden state transformer.
+  attention_dropout: 0.0 # Dropout probability for attention
+  ffn_dropout: 0.0 # Dropout probability in the feed-forward layer.
+  kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
+  apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
+  normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
+  layernorm_epsilon: 1e-5
+  do_layer_norm_weight_decay: False # True means weight decay on all params
+  make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
+  pre_process: True # add embedding
+  post_process: True # add pooler
+  persist_layer_norm: True # Use of persistent fused layer norm kernel.
+  bias: False # Whether to use bias terms in all weight matrices.
+  activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
+  headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
+  transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
+  openai_gelu: False # Use OpenAI's GELU instead of the default GeLU
+  normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to set this to True.
+  position_embedding_type: 'rope' # Position embedding type. Options ['learned_absolute', 'rope']
+  rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this.
+  attention_type: 'multihead' # Attention type. Options ['multihead']
+  share_embeddings_and_output_weights: False # Share embedding and output layer weights.
+  overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
+  batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
+  num_query_groups: 1 # Number of query groups for group query attention. If None, normal attention is used. | 7b: 1 | 40b: 8 | 180b: 8
+  gc_interval: 0
+  precision: bf16
+  mcore_customization_config:
+    new_decoder_architecture: false
+    parallel_attention: true
+
+  tokenizer:
+    library: 'huggingface'
+    type: 'tiiuae/falcon-7b'
+    use_fast: True
+
+  # Mixed precision
+  native_amp_init_scale: 4294967296 # 2 ** 32
+  native_amp_growth_interval: 1000
+  hysteresis: 2 # Gradient scale hysteresis
+  fp32_residual_connection: False # Move residual connections to fp32
+  fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16
+
+  # Megatron O2-style half-precision
+  megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
+  grad_allreduce_chunk_size_mb: 125
+
+  # Fusion
+  grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism.
+  gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2.
+  bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
+  bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition.
+  masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with its mask.
+  get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.
+
+
+  # Miscellaneous
+  seed: 1234
+  resume_from_checkpoint: null # manually set the checkpoint file to load from
+  use_cpu_initialization: False # Init weights on the CPU (slow for large models)
+  onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
+  apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
+  gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
+  sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages
+
+  ## Activation Checkpointing
+  # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
+  # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
+  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
+  # 'full' will checkpoint the entire transformer layer.
+  activations_checkpoint_granularity: null # 'selective' or 'full'
+  activations_checkpoint_method: null # 'uniform', 'block'
+  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
+  # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
+  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
+  activations_checkpoint_num_layers: null
+  # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
+  # when using 'block' this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
+  num_micro_batches_with_partial_activation_checkpoints: null
+  # This feature is valid only when used with pipeline-model-parallelism.
+  # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
+  # and recomputed within a window of micro-batches. The rest of the micro-batches in the window checkpoint all Transformer layers. The size of the window is
+  # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
+  # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'.
+  # This feature enables using activation checkpointing at a fraction of micro-batches up to the point of full GPU memory usage.
+  activations_checkpoint_layers_per_pipeline: null
+  # This feature is valid only when used with pipeline-model-parallelism.
+  # When an integer value (rounded down when a float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
+  # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 checkpoint 3 layers fewer than
+  # stage 0 and stage 2 checkpoint 6 layers fewer than stage 0, and so on. This is possible because later pipeline stages
+  # use less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
+  # this feature removes most of the activation checkpoints at the last pipeline stage, which is the critical execution path.
+
+  ## Sequence Parallelism
+  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
+  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
+  sequence_parallel: False
+
+  ## Transformer Engine
+  fp8: False # enables fp8 in TransformerLayer forward
+  fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3
+  fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID
+  fp8_margin: 0 # scaling margin
+  fp8_interval: 1 # scaling update interval
+  fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor
+  fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history
+  reduce_amax: True # Perform reduction to sync amax tensors across GPUs after every iteration
+  use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False.
+
+  data:
+    # Path to data must be specified by the user.
+    # Supports List, String and Dictionary
+    # List : can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]",
+    # Or see example below:
+    # data_prefix:
+    #   - .5
+    #   - /raid/data/pile/my-gpt3_00_text_document
+    #   - .5
+    #   - /raid/data/pile/my-gpt3_01_text_document
+    # Dictionary: can override from CLI "model.data.data_prefix"={"train":[1.0, /path/to/data], "validation":/path/to/data, "test":/path/to/test}
+    # Or see example below:
+    # "model.data.data_prefix: {train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}"
+    # data_prefix: ???
+    index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
+    data_impl: mmap
+    splits_string: 900,50,50
+    seq_length: ${model.encoder_seq_length}
+    skip_warmup: True
+    num_workers: 2
+    dataloader_type: single # cyclic
+    reset_position_ids: False # Reset position ids after end-of-document token
+    reset_attention_mask: False # Reset attention mask after end-of-document token
+    eod_mask_loss: False # Mask loss for the end of document tokens
+    validation_drop_last: True # Set to False if the last partial validation samples are to be consumed
+    no_seqlen_plus_one_input_tokens: False # Set to True to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token
+    pad_samples_to_global_batch_size: False # Set to True if you want to pad the last partial batch with -1's to equal global batch size
+    shuffle_documents: True # Set to False to disable document shuffling. Sample index will still be shuffled
+
+  # Nsys profiling options
+  nsys_profile:
+    enabled: False
+    start_step: 10 # Global batch to start profiling
+    end_step: 10 # Global batch to end profiling
+    ranks: [0] # Global rank IDs to profile
+    gen_shape: False # Generate model and kernel details including input shapes
+
+  optim:
+    name: distributed_fused_adam
+    lr: 2e-4
+    weight_decay: 0.01
+    betas:
+    - 0.9
+    - 0.98
+    sched:
+      name: CosineAnnealing
+      warmup_steps: 500
+      constant_steps: 50000
+      min_lr: 2e-5
\ No newline at end of file
diff --git a/examples/nlp/language_modeling/conf/megatron_falcon_inference.yaml b/examples/nlp/language_modeling/conf/megatron_falcon_inference.yaml
new file mode 100644
index 000000000000..1ccc9ed5dff8
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/megatron_falcon_inference.yaml
@@ -0,0 +1,38 @@
+inference:
+  greedy: False # Whether or not to use sampling; use greedy decoding otherwise
+  top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
+  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+  temperature: 1.0 # sampling temperature
+  add_BOS: False # add the bos token at the beginning of the prompt
+  tokens_to_generate: 30 # The maximum number of tokens to generate.
+  all_probs: False # whether to return the log prob for all the tokens in the vocab
+  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
+  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
+  compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False
+  end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated
+
+trainer:
+  devices: 1
+  num_nodes: 1
+  accelerator: gpu
+  logger: False # logger provided by exp_manager
+  precision: bf16 # 16, 32, or bf16
+  use_distributed_sampler: False
+
+tensor_model_parallel_size: 1
+pipeline_model_parallel_size: 1
+megatron_amp_O2: True # Enable O2-level automatic mixed precision to save memory
+gpt_model_file: null # GPT nemo file path
+checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training
+checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading
+hparams_file: null # model configuration file, only used for PTL checkpoint loading
+prompts: # prompts for GPT inference
+  - "Q: How are you?"
+  - "Q: How big is the universe?"
+server: False # whether to launch the API server
+port: 5555 # the port number for the inference server
+web_server: False # whether to launch the web inference server
+share: False # whether to create a public URL
+username: test # user name for web client
+password: test2 # password for web client
+web_port: 9889 # the port number of the web server
diff --git a/nemo/collections/nlp/models/language_modeling/megatron/falcon/__init__.py b/nemo/collections/nlp/models/language_modeling/megatron/falcon/__init__.py
new file mode 100644
index 000000000000..4fc50543f1d2
--- /dev/null
+++ b/nemo/collections/nlp/models/language_modeling/megatron/falcon/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py new file mode 100644 index 000000000000..67c732c6aee2 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py @@ -0,0 +1,220 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + +try: + from megatron.core import parallel_state + from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensor + from megatron.core.transformer.enums import AttnMaskType + from megatron.core.transformer.spec_utils import build_module + from megatron.core.transformer.transformer_config import TransformerConfig + from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + from megatron.core.utils import make_viewless_tensor + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + TransformerLayer = ApexGuardDefaults + TransformerConfig = ApexGuardDefaults + TransformerLayerSubmodules = ApexGuardDefaults + AttnMaskType = ApexGuardDefaults() + +""" We use the following notation throughout this file: + h: hidden size + n: number of attention heads + p: number of model parallel partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + l: number of layers + Transformer takes input of size [s, b, h] and returns a + tensor of the same size. We use the following arguments: + hyperparameters: transformer hyperparameters +""" + + +class FalconTransformerLayer(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + self_attn_mask_type=AttnMaskType.padding, + ): + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." 
+ ) + super().__init__(config=config, submodules=submodules, layer_number=layer_number) + + if hasattr(self.config, 'new_decoder_architecture'): + self.new_decoder_architecture = self.config.new_decoder_architecture + else: + self.new_decoder_architecture = None + if hasattr(self.config, 'parallel_attention'): + self.parallel_attention = self.config.parallel_attention + else: + self.parallel_attention = None + + if self.new_decoder_architecture or self.parallel_attention: + self.post_self_attn_layernorm = None + else: + self.post_self_attn_layernorm = build_module( + submodules.post_self_attn_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + if self.new_decoder_architecture: + self.pre_mlp_layernorm = build_module( + submodules.pre_mlp_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + else: + self.pre_mlp_layernorm = None + + def forward( + self, + hidden_states, + attention_mask, + context=None, + context_mask=None, + rotary_pos_emb=None, + inference_params=None, + ): + # hidden_states: [s, b, h] + + # Residual connection. + residual = hidden_states + + mlp_ln_output = None + if self.new_decoder_architecture: + mlp_ln_output = self.pre_mlp_layernorm(hidden_states) + + input_layernorm_output = self.input_layernorm(hidden_states) + + input_mlp_ln = input_layernorm_output + + # Self attention. + attention_output_with_bias = self.self_attention( + input_layernorm_output, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)( + attention_output_with_bias, residual, self.config.hidden_dropout + ) + + if not self.new_decoder_architecture: + if self.parallel_attention: + layernorm_output = input_mlp_ln + else: + residual = hidden_states + layernorm_output = self.post_self_attn_layernorm(hidden_states) + + else: + layernorm_output = mlp_ln_output + + mlp_output_with_bias = self.mlp(layernorm_output) + + # falcon specific: + if self.new_decoder_architecture or self.parallel_attention: + mlp_output = mlp_output_with_bias[0] + attn_output = attention_output_with_bias[0] + mlp_output_without_bias = mlp_output + attn_output + mlp_output_with_bias = (mlp_output_without_bias, None) + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)( + mlp_output_with_bias, residual, self.config.hidden_dropout + ) + + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + + return output, context + + def sharded_state_dict(self, prefix=''): + + state_dict = self.state_dict(keep_vars=True) + + tensor_parallel_layers_axis_map = { + 'self_attention.linear_qkv.weight': 0, + 'self_attention.linear_qkv.bias': 0, + 'self_attention.linear_proj.weight': 1, + 'mlp.linear_fc1.weight': 0, + 'mlp.linear_fc1.bias': 0, + 'mlp.linear_fc2.weight': 1, + } + + offset = self._get_layer_offset() + num_layers = self.config.num_layers + + sharded_state_dict = {} + + for layer_name in state_dict.keys(): + tensor = state_dict[layer_name] + global_layer_offset = self.layer_number - 1 # self.layer_number starts at 1 + layer_key = f'{prefix}{global_layer_offset - offset}.{layer_name}' # module list index in TransformerBlock + sharded_offsets = [(0, global_layer_offset, num_layers)] # PP sharding + + if 
layer_name in tensor_parallel_layers_axis_map: + tp_axis = tensor_parallel_layers_axis_map[layer_name] + # TP sharding + sharded_offsets.append( + [ + tp_axis + 1, # +1 for PP dimension + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_tensor_model_parallel_world_size(), + ] + ) + replica_id = parallel_state.get_data_parallel_rank() + else: + replica_id = ( + parallel_state.get_data_parallel_rank() * parallel_state.get_data_parallel_world_size() + + parallel_state.get_tensor_model_parallel_rank() + ) + + if layer_name.endswith('._extra_state'): + sharded_state_dict[layer_key] = ShardedObject( + f'{prefix}{layer_name}', tensor, (num_layers,), (global_layer_offset,), replica_id, + ) + + else: + sharded_state_dict[layer_key] = ShardedTensor.from_rank_offsets( + f'{prefix}{layer_name}', + tensor, + *sharded_offsets, + replica_id=replica_id, + prepend_axis_num=1, # for PP sharding + ) + + return sharded_state_dict diff --git a/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py new file mode 100644 index 000000000000..924e5f4321e6 --- /dev/null +++ b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + +try: + from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add + from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules + from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TENorm, + TERowParallelLinear, + ) + from megatron.core.transformer.enums import AttnMaskType + from megatron.core.transformer.mlp import MLP, MLPSubmodules + from megatron.core.transformer.spec_utils import ModuleSpec + + from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + ModuleSpec = ApexGuardDefaults + +from .falcon_decoder_layer import FalconTransformerLayer + +# Use this spec for an implementation using modules in TE +def get_falcon_layer_spec() -> ModuleSpec: + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." 
+ ) + falcon_submodules = TransformerLayerSubmodules( + input_layernorm=TENorm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=TENorm, + mlp=ModuleSpec( + module=MLP, submodules=MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear,), + ), + mlp_bda=get_bias_dropout_add, + ) + # Old falcon(prior to 7b/40b/180b) uses post_self_attn_layernorm that is not included in TransformerLayerModules. + falcon_submodules.post_self_attn_layernorm = TENorm + return ModuleSpec(module=FalconTransformerLayer, submodules=falcon_submodules) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 58ed5f458e6f..37a237449f25 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -31,6 +31,7 @@ MegatronPretrainingSampler, ) from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets +from nemo.collections.nlp.models.language_modeling.megatron.falcon.falcon_spec import get_falcon_layer_spec from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel from nemo.collections.nlp.modules.common.megatron.build_model import build_model @@ -103,6 +104,13 @@ HAVE_TE = False +def get_specs(spec_name): + name_spec_dict = {"": get_gpt_layer_with_transformer_engine_spec(), "megatron_falcon_gpt": get_falcon_layer_spec()} + if spec_name not in name_spec_dict: + raise ValueError(f"Spec name '{spec_name}' is not recognized.") + return name_spec_dict[spec_name] + + class MegatronGPTExportableModel(torch.nn.Module, Exportable): """ Megatron GPT Wrapper for ONNX export @@ -215,6 +223,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) self.mcore_gpt = cfg.get('mcore_gpt', False) + self.spec_name = cfg.get('name', '') self.rampup_batch_size = self.cfg.get('rampup_batch_size', None) if self.rampup_batch_size: @@ -297,7 +306,7 @@ def model_provider_func(self, pre_process, post_process): if self.mcore_gpt: model = MCoreGPTModel( config=self.transformer_config, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), + transformer_layer_spec=get_specs(self.spec_name), vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), max_sequence_length=self.cfg.get('encoder_seq_length', 512), pre_process=pre_process, @@ -1514,4 +1523,9 @@ def build_transformer_config(self) -> TransformerConfig: for key, value in model_specific_configs.items(): setattr(transformer_config, key, value) + # pass mcore customization configs directly to mcore + mcore_customization_config_dict = self.cfg.get('mcore_customization_config', {}) + for key, value in mcore_customization_config_dict.items(): + setattr(transformer_config, key, value) + return transformer_config diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index 603d388312f4..eb86c8324dcd 100644 --- 
a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -18,6 +18,10 @@ from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.transformer.attention import SelfAttention +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TELayerNormColumnParallelLinear, +) from megatron.core.transformer.mlp import MLP from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import make_viewless_tensor @@ -63,13 +67,32 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): Derives `query`, `key` and `value` tensors from `hidden_states`. """ # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] - (mixed_qkv, layernorm_output), _ = self.linear_qkv(hidden_states) + linear_qkv_output, _ = self.linear_qkv(hidden_states) + layernorm_output = None + + # In megatron/core/models/gpt/gpt_layer_specs.py TELayerNormColumnParallelLinear is used for linear_qkv. + # TELayerNormColumnParallelLinear fused LN and linear, both will be returned. + # In nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py TEColumnParallelLinear is used for linear_qkv, + # which only returns linear. + if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear): + mixed_qkv, layernorm_output = linear_qkv_output + elif isinstance(self.linear_qkv, TEColumnParallelLinear): # only mixed_qkv + mixed_qkv = linear_qkv_output + else: + raise ValueError( + f"Unrecognized module type '{type(self.linear_qkv)}' when getting query, key, value tensors for mcore mixins. " + ) # LoRA logic if self.is_adapter_available(): lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER) if lora_kqv_adapter: - lora_mixed_qkv = lora_kqv_adapter(layernorm_output) + if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear): + lora_mixed_qkv = lora_kqv_adapter(layernorm_output) + elif isinstance(self.linear_qkv, TEColumnParallelLinear): + lora_mixed_qkv = lora_kqv_adapter(hidden_states) + else: + raise ValueError(f"Unrecognized module type '{type(self.linear_qkv)}' when applying lora.") mixed_qkv = mixed_qkv + lora_mixed_qkv # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] diff --git a/scripts/nlp_language_modeling/convert_hf_falcon_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_falcon_to_nemo.py new file mode 100644 index 000000000000..ef9410b1b929 --- /dev/null +++ b/scripts/nlp_language_modeling/convert_hf_falcon_to_nemo.py @@ -0,0 +1,291 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Conversion script to convert Huggingface Falcon 1B/7B/40B/180B checkpoints into nemo checkpoint. + +This script will generate a Megatron model with TP=1 and PP=1. 
The new dist ckpt format does not require +user to run additional script to set the TP/PP values manually. + +Example to run this conversion script: +``` + python convert_hf_falcon_to_nemo.py \ + --config /path/to/megatron_falcon_config.yaml \ + --input /path/to/hf/checkpoints/folder \ + --output /path/to/output/nemo/file \ + --precision +``` +""" + +import argparse +import time +from typing import Dict + +import pytorch_lightning as pl +import torch +import yaml +from omegaconf import OmegaConf +from transformers import AutoModelForCausalLM, FalconConfig + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.utils import logging + + +def convert_state_dict(state_dict: Dict[str, torch.Tensor], amp: bool = False): + def get_new_key(old_key): + if old_key == "transformer.word_embeddings.weight": + return "embedding.word_embeddings.weight" + elif old_key.startswith("transformer.ln_f"): + return old_key.replace("transformer.ln_f", "decoder.final_layernorm") + elif old_key.startswith("lm_head"): + return old_key.replace("lm_head", "output_layer") + + # For the rest, a base transformation + key = old_key.replace("transformer.h", "decoder.layers") + + # Handling the layer normalization replacements + if falcon_config.new_decoder_architecture: + key = key.replace("ln_attn", "input_layernorm") + key = key.replace("ln_mlp", "pre_mlp_layernorm") + else: + key = key.replace("input_layernorm", "input_layernorm") + if not falcon_config.parallel_attn: + key = key.replace("post_attention_layernorm", "post_self_attn_layernorm") + + key = key.replace("self_attention.dense", "self_attention.linear_proj") + key = key.replace("self_attention.query_key_value", "self_attention.linear_qkv") + key = key.replace("dense_h_to_4h", "linear_fc1") + key = key.replace("dense_4h_to_h", "linear_fc2") + return key + + new_dict = {} + # amp O2 mode has different state dict name + prefix = "model.module." if amp else "model." + + for old_key, val in state_dict.items(): + new_key = get_new_key(old_key) + new_key = prefix + new_key + new_dict[new_key] = val + + return new_dict + + +def load_falcon_config(args) -> FalconConfig: + """ Helper utility to load FalconConfig. + + Legacy Falcon-7B and Falcon-40B are not compatible with `transformers.FalconConfig` and + `transformers.FalconModel`. need to manually set the config values + and force to `falcon` model type. 
+ """ + config = FalconConfig.from_pretrained(args.input) + if config.model_type == 'RefinedWeb': + mappings = { + "num_hidden_layers": config.n_layer, + "num_attention_heads": config.n_head, + "num_kv_heads": config.n_head_kv, + "new_decoder_architecture": True, + } + elif config.model_type == 'RefinedWebModel': + mappings = { + "num_hidden_layers": config.n_layer, + "num_attention_heads": config.n_head, + "num_kv_heads": 1 if config.multi_query else config.n_head, + "new_decoder_architecture": False, + } + else: + return config + + for key, value in mappings.items(): + setattr(config, key, value) + + config.model_type = 'falcon' + return config + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Path to the megatron_gpt_config.yaml file") + parser.add_argument( + "--input", + type=str, + required=True, + help="Path to Falcon variants checkpoint from HuggingFace hub or local dir", + ) + parser.add_argument("--output", type=str, required=True, help="Path to dir where to store output .nemo file") + parser.add_argument( + "--precision", type=str, default="bf16", choices=["bf16", "32"], help="Precision for checkpoint weights saved" + ) + parser.add_argument("--cuda", action="store_true", help="Put Nemo model onto GPU prior to saving") + + args = parser.parse_args() + + falcon_config = load_falcon_config(args) + with open(args.config, "r", encoding="utf_8") as f: + orig_cfg = yaml.safe_load(f) + + model_dict = orig_cfg["model"] + + if "data" in model_dict: + del model_dict["data"] + + override_model_dict = { + "micro_batch_size": 1, + "global_batch_size": 1, + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "megatron_amp_O2": False, + "transformer_engine": True, + "use_cpu_initialization": not args.cuda, + "normalization": "layernorm", + "mcore_gpt": True, + "num_query_groups": None, # MHA + "hidden_size": falcon_config.hidden_size, + "encoder_seq_length": falcon_config.max_position_embeddings, + "max_position_embeddings": falcon_config.max_position_embeddings, + "num_layers": falcon_config.num_hidden_layers, + "num_attention_heads": falcon_config.num_attention_heads, + "ffn_hidden_size": falcon_config.hidden_size * 4, + "layernorm_epsilon": falcon_config.layer_norm_epsilon, + "pre_process": True, + "post_process": True, + "apply_query_key_layer_scaling": False, + "bias": falcon_config.bias, + "transformer_block_type": "pre_ln", + "fp32_residual_connection": False, + "hidden_dropout": falcon_config.hidden_dropout, + "attention_dropout": falcon_config.attention_dropout, + "ffn_dropout": 0, + "share_embeddings_and_output_weights": False, + "position_embedding_type": "rope", + "precision": args.precision, + "init_method_std": falcon_config.initializer_range, + "activation": "gelu", + "bias_activation_fusion": False, + "bias_dropout_add_fusion": False, + "seq_len_interpolation_factor": None, + } + + mcore_customization_config_dict = { + "new_decoder_architecture": falcon_config.new_decoder_architecture, + "parallel_attention": falcon_config.parallel_attn, + } + + tokenizer_dict = { + "library": "huggingface", + "type": args.input, + "use_fast": True, + } + trainer_dict = { + "devices": 1, + "num_nodes": 1, + "accelerator": "gpu" if args.cuda else "cpu", + "precision": args.precision, + "logger": False, + "enable_checkpointing": False, + "max_epochs": -1, + "max_steps": 100000, + "log_every_n_steps": 10, + "val_check_interval": 100, + "limit_val_batches": 50, + "limit_test_batches": 500, + 
"accumulate_grad_batches": 1, + "gradient_clip_val": 1.0, + "benchmark": False, + "enable_model_summary": False, + "strategy": NLPDDPStrategy(), + } + + # Additional logic for position_embedding_type = alibi + if falcon_config.alibi: + try: + raise ValueError( + "Alibi is not yet supported in Megatron Core, \ + force to use RoPE will generate suboptimal responses" + ) + except ValueError as e: + print(e) + + # Additional logic for num_query_groups + if override_model_dict.get("num_query_groups") is None: + if falcon_config.new_decoder_architecture: + override_model_dict["num_query_groups"] = falcon_config.num_kv_heads + elif falcon_config.multi_query: + override_model_dict["num_query_groups"] = 1 + + # Additional logic for bias fusion + if falcon_config.bias: + override_model_dict["bias_activation_fusion"] = True + override_model_dict["bias_dropout_add_fusion"] = True + + # Addtional logic for rope scaling + if falcon_config.rope_scaling is not None: + if falcon_config.rope_scaling.type == 'linear': + override_model_dict['seq_len_interpolation_factor'] = falcon_config.rope_scaling.factor + else: + raise ValueError("Only linear rope scaling type is supported now") + + model_dict.update(override_model_dict) + model_dict["tokenizer"] = tokenizer_dict + model_dict["name"] = 'megatron_falcon_gpt' + model_dict["mcore_customization_config"] = mcore_customization_config_dict + + omega_cfg = OmegaConf.create(model_dict) + + trainer = pl.Trainer(**trainer_dict) + + logging.info("Creating Megatron model...") + tik = time.time() + model = MegatronGPTModel(omega_cfg, trainer) + + logging.info("Loading HuggingFace model...") + model_hf = AutoModelForCausalLM.from_pretrained(args.input) + + state_dict_hf = model_hf.state_dict() + convert_dict = convert_state_dict(state_dict_hf, amp=omega_cfg.megatron_amp_O2) + + logging.info("Loading state dict...") + missing_keys, unexpected_keys = model.load_state_dict(convert_dict, strict=False) + + if missing_keys: + # Keys ending with '_extra_state' are related to Transformer Engine internals + missing_keys_non_extra = [key for key in missing_keys if not key.endswith("_extra_state")] + if missing_keys_non_extra: + logging.critical("Missing keys were detected during the load, something has gone wrong. Aborting.") + raise RuntimeError(f"Missing keys: \n{missing_keys_non_extra}") + + if unexpected_keys: + logging.critical("Unexpected keys were detected which should not happen. Aborting.") + raise RuntimeError(f"Unexpected keys: \n{unexpected_keys}") + + logging.info("Saving model...") + + # We make sure that the tokenizer can be instantiated later regardless of args.input + if falcon_config.new_decoder_architecture: + model.cfg.tokenizer.update(type="tiiuae/falcon-40b") + elif falcon_config.multi_query: + model.cfg.tokenizer.update(type="tiiuae/falcon-7b") + elif falcon_config.alibi and falcon_config.num_hidden_layers == 36: + model.cfg.tokenizer.update(type="tiiuae/falcon-rw-7b") + else: + model.cfg.tokenizer.update(type="tiiuae/falcon-rw-1b") + + dtype = torch.bfloat16 if args.precision == "bf16" else torch.float32 + model = model.to(dtype=dtype) + model.cfg.update(use_cpu_initialization=False) + model.save_to(args.output) + logging.info(f'Done. NeMo model saved to: {args.output}') + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + logging.info(f'nemo model created and saved. 
Total time: {t}')
diff --git a/scripts/nlp_language_modeling/convert_nemo_falcon_to_hf.py b/scripts/nlp_language_modeling/convert_nemo_falcon_to_hf.py
new file mode 100644
index 000000000000..66f6399855a3
--- /dev/null
+++ b/scripts/nlp_language_modeling/convert_nemo_falcon_to_hf.py
@@ -0,0 +1,171 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+from pytorch_lightning import Trainer
+from transformers import AutoModelForCausalLM
+
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
+from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
+from nemo.utils import logging
+
+"""
+Script to convert a Falcon checkpoint in NeMo (mcore path) into a HuggingFace checkpoint.
+This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder.
+
+1) Generate only the HF weights from a nemo file:
+
+    python convert_nemo_falcon_to_hf.py \
+    --in-file /path/to/file.nemo or /path/to/extracted_folder \
+    --out-file /path/to/pytorch_model.bin
+
+2) Generate the full HF model folder:
+
+    python convert_nemo_falcon_to_hf.py \
+    --in-file /path/to/file.nemo or /path/to/extracted_folder \
+    --out-file /path/to/pytorch_model.bin \
+    --hf-in-path /path/to/input_hf_folder \
+    --hf-out-path /path/to/output_hf_folder
+
+    Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Falcon 180B).
+    However, this option makes the conversion script significantly slower.
+"""
+
+
+def get_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--in-file", type=str, required=True, help="Path to .nemo file",
+    )
+    parser.add_argument("--out-file", type=str, required=True, help="Path to HF .bin file")
+    parser.add_argument(
+        "--hf-in-path",
+        type=str,
+        default=None,
+        help="A HF model path, "
+        "e.g. a folder containing https://huggingface.co/tiiuae/falcon-7b/tree/main",
+    )
+    parser.add_argument(
+        "--hf-out-path",
+        type=str,
+        default=None,
+        help="Output HF model path, " "with the same format as above but with the user's own weights",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default=None,
+        help="Precision of output weights. "
+        "Defaults to precision of the input nemo weights (model.cfg.trainer.precision)",
+    )
+    parser.add_argument(
+        "--cpu-only",
+        action="store_true",
+        help="Load model in cpu only. 
Useful if the model cannot fit in GPU memory, " + "but this option makes the conversion script significantly slower.", + ) + args = parser.parse_args() + return args + + +def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None: + """ + Convert NeMo weights to HF weights + """ + dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy()) + if cpu_only: + map_location = torch.device('cpu') + model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True) + model_config.use_cpu_initialization = True + model_config.tensor_model_parallel_size = 1 + else: + map_location, model_config = None, None + + if cpu_only: + logging.info("******** Loading model on CPU. This will take a significant amount of time.") + model = MegatronGPTModel.restore_from( + input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location + ) + if precision is None: + precision = model.cfg.precision + try: + dtype = torch_dtype_from_precision(precision) + except ValueError as e: + logging.warning(str(e) + f", precision string '{precision}' is not recognized, falling back to fp32") + dtype = torch.float32 # fallback + + param_to_weights = lambda param: param.to(dtype) + checkpoint = OrderedDict() + + def get_original_key(new_key): + new_key = new_key[len(prefix) :] + + if new_key.startswith("embedding.word_embeddings.weight"): + return "transformer.word_embeddings.weight" + elif new_key.startswith("decoder.final_layernorm"): + return new_key.replace("decoder.final_layernorm", "transformer.ln_f") + elif new_key.startswith("output_layer"): + return new_key.replace("output_layer", "lm_head") + + key = new_key.replace("decoder.layers", "transformer.h") + + if model.cfg.mcore_customization_config.new_decoder_architecture: + key = key.replace("input_layernorm", "ln_attn") + key = key.replace("pre_mlp_layernorm", "ln_mlp") + else: + key = key.replace("input_layernorm", "input_layernorm") + if not model.cfg.mcore_customization_config.parallel_attention: + key = key.replace("post_self_attn_layernorm", "post_attention_layernorm") + + key = key.replace("self_attention.linear_proj", "self_attention.dense") + key = key.replace("self_attention.linear_qkv", "self_attention.query_key_value") + key = key.replace("linear_fc1", "dense_h_to_4h") + key = key.replace("linear_fc2", "dense_4h_to_h") + return key + + prefix = 'model.module.' if any(k.startswith('model.module.') for k in model.state_dict()) else 'model.' 
+ + for key, value in model.state_dict().items(): + if '_extra_state' in key: + continue + orig_key = get_original_key(key) + checkpoint[orig_key] = param_to_weights(value) + + os.makedirs(os.path.dirname(output_hf_file), exist_ok=True) + torch.save(checkpoint, output_hf_file) + logging.info(f"Weights reverted and saved to {output_hf_file}") + + +def replace_hf_weights(weights_file, input_hf_path, output_hf_path): + model = AutoModelForCausalLM.from_pretrained(input_hf_path, local_files_only=True) + nemo_exported = torch.load(weights_file) + + model.load_state_dict(nemo_exported) + model.save_pretrained(output_hf_path) + logging.info(f"Full HF model saved to {output_hf_path}") + + +if __name__ == '__main__': + args = get_args() + convert(args.in_file, args.out_file, precision=args.precision, cpu_only=args.cpu_only) + if args.hf_in_path and args.hf_out_path: + replace_hf_weights(args.out_file, args.hf_in_path, args.hf_out_path) + else: + logging.info("`hf-in-path` and/or `hf-out-path` not provided, not generating full HF model.") + logging.info(f".bin file is saved to {args.out_file}") diff --git a/tests/collections/nlp/test_falcon_model.py b/tests/collections/nlp/test_falcon_model.py new file mode 100644 index 000000000000..23430ad36300 --- /dev/null +++ b/tests/collections/nlp/test_falcon_model.py @@ -0,0 +1,262 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch +from omegaconf import DictConfig +from pytorch_lightning import Trainer + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy + +DEVICE_CAPABILITY = None +if torch.cuda.is_available(): + DEVICE_CAPABILITY = torch.cuda.get_device_capability() + + +@pytest.fixture() +def model_cfg(test_data_dir): + + model_cfg = { + 'mcore_gpt': True, + 'micro_batch_size': 4, + 'global_batch_size': 8, + 'rampup_batch_size': None, + 'tensor_model_parallel_size': 1, + 'pipeline_model_parallel_size': 1, + 'virtual_pipeline_model_parallel_size': None, + 'encoder_seq_length': 512, + 'max_position_embeddings': 512, + 'num_layers': 1, + 'hidden_size': 128, + 'ffn_hidden_size': 512, + 'num_attention_heads': 2, + 'num_query_groups': 1, + 'init_method_std': 0.02, + 'use_scaled_init_method': True, + 'hidden_dropout': 0.0, + 'attention_dropout': 0.0, + 'ffn_dropout': 0, + 'kv_channels': None, + 'apply_query_key_layer_scaling': False, + 'normalization': 'layernorm', + 'layernorm_epsilon': 1e-05, + 'do_layer_norm_weight_decay': False, + 'make_vocab_size_divisible_by': 128, + 'pre_process': True, + 'post_process': True, + 'persist_layer_norm': True, + 'bias': False, + 'activation': 'gelu', + 'headscale': False, + 'transformer_block_type': 'pre_ln', + 'openai_gelu': False, + 'normalize_attention_scores': True, + 'position_embedding_type': 'rope', + 'rotary_percentage': 1.0, + 'attention_type': 'multihead', + 'share_embeddings_and_output_weights': False, + 'overlap_p2p_comm': False, + 'batch_p2p_comm': True, + 'seq_len_interpolation_factor': None, + 'tokenizer': {'library': 'huggingface', 'type': 'tiiuae/falcon-40b', 'use_fast': True}, + 'native_amp_init_scale': 4294967296, + 'native_amp_growth_interval': 1000, + 'hysteresis': 2, + 'fp32_residual_connection': False, + 'fp16_lm_cross_entropy': False, + 'megatron_amp_O2': False, + 'grad_allreduce_chunk_size_mb': 125, + 'grad_div_ar_fusion': True, + 'gradient_accumulation_fusion': False, + 'bias_activation_fusion': False, + 'bias_dropout_add_fusion': False, + 'masked_softmax_fusion': True, + 'get_attention_mask_from_fusion': True, + 'seed': 1234, + 'resume_from_checkpoint': None, + 'use_cpu_initialization': False, + 'onnx_safe': False, + 'apex_transformer_log_level': 30, + 'gradient_as_bucket_view': True, + 'sync_batch_comm': False, + 'activations_checkpoint_granularity': None, + 'activations_checkpoint_method': None, + 'activations_checkpoint_num_layers': None, + 'num_micro_batches_with_partial_activation_checkpoints': None, + 'activations_checkpoint_layers_per_pipeline': None, + 'sequence_parallel': False, + 'transformer_engine': True, + 'fp8': False, + 'fp8_e4m3': False, + 'fp8_hybrid': False, + 'fp8_margin': 0, + 'fp8_interval': 1, + 'fp8_amax_history_len': 1, + 'fp8_amax_compute_algo': 'most_recent', + 'reduce_amax': True, + 'use_emha': False, + 'ub_tp_comm_overlap': False, + 'ub_tp_comm_overlap_cfg': None, + 'use_flash_attention': False, + 'nsys_profile': {'enabled': False, 'start_step': 10, 'end_step': 10, 'ranks': [0], 'gen_shape': False}, + 'optim': { + 'name': 'distributed_fused_adam', + 'lr': '2e-4', + 'weight_decay': 0.01, + 'betas': [0.9, 0.98], + 'sched': {'name': 'CosineAnnealing', 'warmup_steps': 500, 'constant_steps': 50000, 'min_lr': 
'2e-5'}, + }, + 'gc_interval': 0, + 'precision': 'bf16', + 'new_decoder_architecture': False, + 'parallel_attention': True, + 'name': 'megatron_falcon_gpt', + 'target': 'nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel', + } + return model_cfg + + +@pytest.fixture() +def trainer_cfg(): + + trainer_cfg = { + 'devices': 1, + 'num_nodes': 1, + 'accelerator': 'gpu', + 'precision': 'bf16', + 'logger': False, + 'enable_checkpointing': False, + 'use_distributed_sampler': False, + 'max_epochs': 1000, + 'max_steps': 100000, + 'log_every_n_steps': 10, + 'val_check_interval': 100, + 'limit_val_batches': 50, + 'limit_test_batches': 500, + 'accumulate_grad_batches': 1, + 'gradient_clip_val': 1.0, + } + + return trainer_cfg + + +@pytest.fixture() +def precision(): + return 'bf16' + + +@pytest.fixture() +def falcon_gpt_model(model_cfg, trainer_cfg, precision): + model_cfg['precision'] = precision + trainer_cfg['precision'] = precision + + strategy = NLPDDPStrategy() + + trainer = Trainer(strategy=strategy, **trainer_cfg) + + cfg = DictConfig(model_cfg) + + model = MegatronGPTModel(cfg=cfg, trainer=trainer) + + return model + + +@pytest.fixture() +def test_text(): + test_text = [ + "hello, world", + "four score and seven years ago", + "Your time is limited", + "If you set goals rediculously high", + ] + return test_text + + +@pytest.mark.run_only_on('GPU') +class TestFalconGPTModel: + @pytest.mark.unit + def test_constructor(self, falcon_gpt_model): + assert isinstance(falcon_gpt_model, MegatronGPTModel) + + num_weights = falcon_gpt_model.num_weights + assert num_weights == 16827136 + + @pytest.mark.unit + def test_tokenizer(self, falcon_gpt_model, test_text): + + assert isinstance(falcon_gpt_model.tokenizer, AutoTokenizer) + assert falcon_gpt_model.tokenizer.name == 'PreTrainedTokenizerFast' + assert falcon_gpt_model.tokenizer.vocab_size == 65024 + + ids = [falcon_gpt_model.tokenizer.text_to_ids(text) for text in test_text] + + true_ids = [ + [30835, 23, 1079], + [18584, 5179, 273, 5144, 909, 2323], + [4560, 601, 304, 3991], + [1424, 299, 889, 4258, 2400, 276, 20201, 986], + ] + assert sum([id_list == true_id_list for id_list, true_id_list in zip(ids, true_ids)]) == 4 + + @pytest.mark.parametrize( + "precision", + [ + 32, + 16, + pytest.param( + "bf16", + marks=pytest.mark.skipif( + not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, + reason='bfloat16 is not supported on this device', + ), + ), + ], + ) + @pytest.mark.unit + def test_forward(self, falcon_gpt_model, test_text): + + dtype = falcon_gpt_model.torch_dtype + + falcon_gpt_model.eval() + + ids = [falcon_gpt_model.tokenizer.text_to_ids(text) for text in test_text] + + id_tensors = [torch.unsqueeze(torch.LongTensor(id_list), dim=0) for id_list in ids] + + masks_and_position_ids = [ + get_ltor_masks_and_position_ids(id_tensor, falcon_gpt_model.tokenizer.eos_id, False, False, False) + for id_tensor in id_tensors + ] + output_tensors = [] + with torch.no_grad(): + for tokens, attn_mask_and_pos_ids in zip(id_tensors, masks_and_position_ids): + attn_mask, _, pos_ids = attn_mask_and_pos_ids + assert tokens.shape == pos_ids.shape + assert attn_mask.shape[2] == attn_mask.shape[3] == tokens.shape[1] == pos_ids.shape[1] + with torch.autocast('cuda', dtype=dtype): + output_tensor = falcon_gpt_model.forward( + tokens=tokens.cuda(), + text_position_ids=pos_ids.cuda(), + attention_mask=attn_mask.cuda(), + labels=None, + ) + # output is [b s h] + assert output_tensor.shape[0] == 1 + assert output_tensor.shape[1] == 
tokens.shape[1] + assert output_tensor.shape[2] == falcon_gpt_model.padded_vocab_size + assert output_tensor.dtype == dtype + output_tensors.append(output_tensor) diff --git a/tests/collections/nlp/test_falcon_transformer_layer.py b/tests/collections/nlp/test_falcon_transformer_layer.py new file mode 100644 index 000000000000..3edb541e8e33 --- /dev/null +++ b/tests/collections/nlp/test_falcon_transformer_layer.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +try: + + from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + from megatron.core.transformer.transformer_config import TransformerConfig + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +from nemo.collections.nlp.models.language_modeling.megatron.falcon.falcon_decoder_layer import FalconTransformerLayer +from nemo.collections.nlp.models.language_modeling.megatron.falcon.falcon_spec import get_falcon_layer_spec + + +@pytest.mark.run_only_on('GPU') +class TestParallelFalconTransformerLayer: + def setup_method(self, method): + model_parallel_cuda_manual_seed(123) + transformer_config = TransformerConfig( + num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True + ) + self.parallel_falcon_transformer_layer = FalconTransformerLayer( + transformer_config, get_falcon_layer_spec().submodules + ) + + @pytest.mark.unit + def test_constructor(self): + parallel_falcon_transformer_layer = self.parallel_falcon_transformer_layer + assert isinstance(parallel_falcon_transformer_layer, FalconTransformerLayer) + assert parallel_falcon_transformer_layer.layer_number == 1 + + num_weights = sum([p.numel() for p in parallel_falcon_transformer_layer.parameters()]) + assert num_weights == 1884 + + @pytest.mark.unit + def test_gpu_forward(self): + parallel_transformer_layer = self.parallel_falcon_transformer_layer + config: TransformerConfig = parallel_transformer_layer.config + sequence_length = 32 + micro_batch_size = 2 + parallel_transformer_layer.cuda() + + # [sequence length, batch size, hidden size] + hidden_states = torch.ones((sequence_length, micro_batch_size, config.hidden_size)) + hidden_states = hidden_states.cuda() + + attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda() + hidden_states, context = parallel_transformer_layer(hidden_states=hidden_states, attention_mask=attention_mask) + assert hidden_states.shape[0] == sequence_length + assert hidden_states.shape[1] == micro_batch_size + assert hidden_states.shape[2] == config.hidden_size
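
The two conversion scripts added above are meant to round-trip a checkpoint between HuggingFace and NeMo. Below is a minimal usage sketch; the paths are placeholders and the Falcon-7B checkpoint id is only an example, while the flags are the ones defined in the scripts' argument parsers.

    # HF -> NeMo: produces a TP=1/PP=1 .nemo file in the new dist-ckpt format
    python scripts/nlp_language_modeling/convert_hf_falcon_to_nemo.py \
        --config examples/nlp/language_modeling/conf/megatron_falcon_config.yaml \
        --input tiiuae/falcon-7b \
        --output /tmp/falcon-7b.nemo \
        --precision bf16

    # NeMo -> HF: writes the HF weights only; add --hf-in-path/--hf-out-path to emit a full HF model folder
    python scripts/nlp_language_modeling/convert_nemo_falcon_to_hf.py \
        --in-file /tmp/falcon-7b.nemo \
        --out-file /tmp/falcon-7b-hf/pytorch_model.bin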