Skip to content

Commit

Permalink
Support Falcon Variants (7B/40B/180B) in Mcore NeMo (#7666)
Browse files Browse the repository at this point in the history
* support falcon

* support falcon bug fix layernorm naming

* fix todo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix for new architecture

* new transformerlayer for falcon

* fix for new decoder architecture

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add DDP

* fix state dict based on spec system

* fix state dict based on change in layers, fix amp O2

* add falcon spec system support

* remove old falcon mcore support

* refactor conversion script to align with others

* add support for falcon-rw model (normal gpt architecture)

* modify falcon 7b config and remove trust remote code due to HF code changes

* rename falcon implementation dir

* change dir name

* modify block name

* rename decoder layer

* clean up

* remove debug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add proper header

Signed-off-by: Vivian <[email protected]>

* falcon lora mixin to support when non-fused LN linear

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revise jenkinsfile, tokenizer update in convertion script, add two falcon config files

Signed-off-by: Vivian <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor falcon to use MCoreGPT+spec+baselayer initial commit

* modification to get nemo run with mcore in this version

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* small fix on the output file path

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add nemo to hf conversion script

* fix on base layer config and missing state dict due to dist ckpt

* Revert "fix on base layer config and missing state dict due to dist ckpt"

This reverts commit c85f3ac.

* fix on base layer config and missing state dict due to dist ckpt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix megatron_gpt_model

Signed-off-by: Vivian chen <[email protected]>

* modify model config

Signed-off-by: Vivian chen <[email protected]>

* Apply suggestions from code review

Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Vivian Chen <[email protected]>

* fix based on review

Signed-off-by: Vivian Chen <[email protected]>

* multiple revise based on review and latest mcore changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Vivian Chen <[email protected]>

* subclass from TransformerLayer

* fixes according to comments

* add falcon ci test

Signed-off-by: Vivian Chen <[email protected]>

* add post_self_attn_layernorm

* add explicit explanation/refs for handling lora logic

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes for code scanning

* remove unused imports

* unit test for falcon model

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add falcon transformer layer unit test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes for code scan

* remove mcore dependent tests

* Revert "remove mcore dependent tests"

This reverts commit 9c4960f.

* add import guards

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add import guards cont

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes for ci import tests and unit tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes for codeql

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "fixes for codeql"

This reverts commit 9028555.

---------

Signed-off-by: Vivian <[email protected]>
Signed-off-by: Vivian chen <[email protected]>
Signed-off-by: Vivian Chen <[email protected]>
Signed-off-by: Vivian Chen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Huiying Li <[email protected]>
Co-authored-by: HuiyingLi <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
  • Loading branch information
5 people committed Dec 15, 2023
1 parent 58a277a commit 8523384
Show file tree
Hide file tree
Showing 12 changed files with 1,400 additions and 3 deletions.
9 changes: 9 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ pipeline {
sh 'rm -f /home/TestData/nlp/megatron_gpt/starcoder-ci-hf/megatron_starcoder_tp1_pp1.nemo'
}
}
stage('Falcon') {
steps {
sh 'python scripts/nlp_language_modeling/convert_hf_falcon_to_nemo.py \
--config examples/nlp/language_modeling/conf/megatron_falcon_config.yaml \
--input /home/TestData/nlp/megatron_gpt/falcon-ci-hf \
--output /home/TestData/nlp/megatron_gpt/falcon-ci-hf/falcon_ci.nemo'
sh 'rm -f /home/TestData/nlp/megatron_gpt/falcon-ci-hf/falcon_ci.nemo'
}
}
}
}

Expand Down
219 changes: 219 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_falcon_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
name: megatron_falcon_gpt
restore_from_path: null # used when starting from a .nemo file

trainer:
devices: 1
num_nodes: 1
accelerator: gpu
precision: bf16
logger: False # logger provided by exp_manager
enable_checkpointing: False
use_distributed_sampler: False
max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10
val_check_interval: 100
limit_val_batches: 50
limit_test_batches: 500
accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models
gradient_clip_val: 1.0
benchmark: False
enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually

exp_manager:
explicit_log_dir: null
exp_dir: null
name: megatron_falcon_gpt
create_wandb_logger: False
wandb_logger_kwargs:
project: null
name: null
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_loss
save_top_k: 10
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits
filename: 'megatron_falcon--{val_loss:.2f}-{step}-{consumed_samples}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}

model:
mcore_gpt: True
# specify micro_batch_size, global_batch_size, and model parallelism
# gradient accumulation will be done automatically based on data_parallel_size
micro_batch_size: 1 # limited by GPU memory
global_batch_size: 1 # will use more micro batches to reach global batch size
tensor_model_parallel_size: 1 # intra-layer model parallelism
pipeline_model_parallel_size: 1 # inter-layer model parallelism
virtual_pipeline_model_parallel_size: null # interleaved pipeline

# model architecture
encoder_seq_length: 2048
max_position_embeddings: ${.encoder_seq_length}
num_layers: 32 # 7b: 32 | 40b: 60 | 180b: 80
hidden_size: 4544 # 7b: 4544 | 40b: 8192 | 180b: 14848
ffn_hidden_size: 18176 # Transformer FFN hidden size. Usually 4 * hidden_size. | 7b: 18176 | 40b: 32768 | 180b: 59392
num_attention_heads: 71 # 7b: 71 | 40b: 128 | 180b: 232
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
use_scaled_init_method: True # use scaled residuals initialization
hidden_dropout: 0.0 # Dropout probability for hidden state transformer.
attention_dropout: 0.0 # Dropout probability for attention
ffn_dropout: 0.0 # Dropout probability in the feed-forward layer.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
layernorm_epsilon: 1e-5
do_layer_norm_weight_decay: False # True means weight decay on all params
make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler
persist_layer_norm: True # Use of persistent fused layer norm kernel.
bias: False # Whether to use bias terms in all weight matrices.
activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu']
headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
openai_gelu: False # Use OpenAI's GELU instead of the default GeLU
normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True.
position_embedding_type: 'rope' # Position embedding type. Options ['learned_absolute', 'rope']
rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this.
attention_type: 'multihead' # Attention type. Options ['multihead']
share_embeddings_and_output_weights: False # Share embedding and output layer weights.
overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
num_query_groups: 1 # Number of query groups for group query attention. If None, normal attention is used. | 7b: 1 | 40b: 8 | 180b: 8
gc_interval: 0
precision: bf16
mcore_customization_config:
new_decoder_architecture: false
parallel_attention: true

tokenizer:
library: 'huggingface'
type: 'tiiuae/falcon-7b'
use_fast: True

# Mixed precision
native_amp_init_scale: 4294967296 # 2 ** 32
native_amp_growth_interval: 1000
hysteresis: 2 # Gradient scale hysteresis
fp32_residual_connection: False # Move residual connections to fp32
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

# Megatron O2-style half-precision
megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
grad_allreduce_chunk_size_mb: 125

# Fusion
grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism..
gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2.
bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition.
masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask.
get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.


# Miscellaneous
seed: 1234
resume_from_checkpoint: null # manually set the checkpoint file to load from
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
# These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
# 'full' will checkpoint the entire transformer layer.
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null
# when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
# when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
num_micro_batches_with_partial_activation_checkpoints: null
# This feature is valid only when used with pipeline-model-parallelism.
# When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
# and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is
# set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
# per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'.
# This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage.
activations_checkpoint_layers_per_pipeline: null
# This feature is valid only when used with pipeline-model-parallelism.
# When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
# pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than
# stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage
# uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
# this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path.

## Sequence Parallelism
# Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
sequence_parallel: False

## Transformer Engine
fp8: False # enables fp8 in TransformerLayer forward
fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3
fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID
fp8_margin: 0 # scaling margin
fp8_interval: 1 # scaling update interval
fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor
fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history
reduce_amax: True # Perform reduction to sync amax tensors across GPUs after every iteration
use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False.

data:
# Path to data must be specified by the user.
# Supports List, String and Dictionary
# List : can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]",
# Or see example below:
# data_prefix:
# - .5
# - /raid/data/pile/my-gpt3_00_text_document
# - .5
# - /raid/data/pile/my-gpt3_01_text_document
# Dictionary: can override from CLI "model.data.data_prefix"={"train":[1.0, /path/to/data], "validation":/path/to/data, "test":/path/to/test}
# Or see example below:
# "model.data.data_prefix: {train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}"
# data_prefix: ???
index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
data_impl: mmap
splits_string: 900,50,50
seq_length: ${model.encoder_seq_length}
skip_warmup: True
num_workers: 2
dataloader_type: single # cyclic
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
validation_drop_last: True # Set to false if the last partial validation samples is to be consumed
no_seqlen_plus_one_input_tokens: False # Set to True to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token
pad_samples_to_global_batch_size: False # Set to True if you want to pad the last partial batch with -1's to equal global batch size
shuffle_documents: True # Set to False to disable documents shuffling. Sample index will still be shuffled

# Nsys profiling options
nsys_profile:
enabled: False
start_step: 10 # Global batch to start profiling
end_step: 10 # Global batch to end profiling
ranks: [0] # Global rank IDs to profile
gen_shape: False # Generate model and kernel details including input shapes

optim:
name: distributed_fused_adam
lr: 2e-4
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 50000
min_lr: 2e-5
38 changes: 38 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_falcon_inference.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
inference:
greedy: False # Whether or not to use sampling ; use greedy decoding otherwise
top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature: 1.0 # sampling temperature
add_BOS: False # add the bos token at the begining of the prompt
tokens_to_generate: 30 # The minimum length of the sequence to be generated.
all_probs: False # whether return the log prob for all the tokens in vocab
repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False
end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated

trainer:
devices: 1
num_nodes: 1
accelerator: gpu
logger: False # logger provided by exp_manager
precision: bf16 # 16, 32, or bf16
use_distributed_sampler: False

tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
megatron_amp_O2: True # Enable O2-level automatic mixed precision to save memory
gpt_model_file: null # GPT nemo file path
checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training
checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading
hparams_file: null # model configuration file, only used for PTL checkpoint loading
prompts: # prompts for GPT inference
- "Q: How are you?"
- "Q: How big is the universe?"
server: False # whether launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether launch the web inference server
share: False # whether create a public URL
username: test # user name for web client
password: test2 # password for web client
web_port: 9889 # the port number of the web server
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit 8523384

Please sign in to comment.