Merge pull request #40 from epfLLM/instruction_tuning
Instruction tuning
martinjaggi authored Sep 2, 2023
2 parents 15b051d + 1ac215d commit 02bb7f4
Showing 17 changed files with 935 additions and 99 deletions.
16 changes: 12 additions & 4 deletions README.md
@@ -9,15 +9,17 @@ This library enables pre-training and fine-tuning of large language models (LLMs
Our repository is a modification of the [original Megatron-LM codebase](https://github.com/NVIDIA/Megatron-LM) by Nvidia.

Added key features include:
- [Llama](https://arxiv.org/abs/2302.13971), [Llama 2](https://arxiv.org/abs/2307.09288) and [Falcon](https://huggingface.co/tiiuae) support
- support training of large models (70B Llama2, 65B Llama1 and 40B Falcon) on commodity hardware on multiple nodes
- [Llama](https://arxiv.org/abs/2302.13971), [Llama 2](https://arxiv.org/abs/2307.09288), [Code Llama](https://arxiv.org/abs/2308.12950) and [Falcon](https://huggingface.co/tiiuae) support
- support training of large models (70B Llama2, 65B Llama1, 34B Code Llama, and 40B Falcon) on commodity hardware on multiple nodes
- 3-way parallelism: tensor parallel, pipeline parallel and data parallel training (inherited from Megatron)
- pretraining and instruct tuning support
- grouped-query attention (GQA) and multi-query attention (MQA)
- Rotary Position Embeddings (RoPE) [was added independently by the Megatron project subsequent to us]
- RMS layer norm
- Rotary Position Embeddings (RoPE), RMS layer norm, Lima dropout
- RoPE scaling for longer attention context support
- FlashAttention 2
- BF16 / FP16 training
- Support for special tokens & tokenizers
- Conversion to and from Hugging Face
- WandB integration

# Documentation
@@ -31,6 +33,11 @@ pip install -r requirements.txt
make html
```

# Example models trained with *Megatron-LLM*
70B Llama 2 [1](https://huggingface.co/OpenAssistant/llama2-70b-oasst-sft-v10),
40B Falcon [1](https://huggingface.co/OpenAssistant/falcon-40b-megacode2-oasst),
13B Code Llama [1](https://huggingface.co/OpenAssistant/codellama-13b-oasst-sft-v10), ...
(Let us know about yours!)

# Citation

@@ -39,6 +46,7 @@ If you use this software please cite it:
@software{epfmgtrn,
author = {Alejandro Hernández Cano and
Matteo Pagliardini and
Andreas Köpf and
Kyle Matoba and
Amirkeivan Mohtashami and
Olivia Simin Fan and
5 changes: 3 additions & 2 deletions docs/guide/getting_started.md
@@ -129,15 +129,15 @@ python weights2megatron/weights2megatron.py llama2 --size=7 \
## Correctness verification (optional)
To make sure the weight conversion ran successfully, we run the `verify_correctness.py` script.
This will run simultaneously the official Falcon implementation and the Megatron codebase.
This will run the official LLaMa 2 implementation and the Megatron codebase simultaneously.
Make sure to adjust the arguments to fit your setup:
```bash
# arguments required by `torchrun`
DISTRIBUTED_ARGS="--nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 8000"
LLAMA_ARGS="--use_rms_norm --glu_activation swiglu --no_tie_embed_logits --no_new_tokens --layernorm_epsilon 1e-5"
COMMON_ARGS="--hidden_dropout 0.0 --attention_dropout 0.0 --no_bias_gelu_fusion"
torchrun $DISTRIBUTED_ARGS verify_correctness.py \
--model_name=falcon \
--model_name=llama2 \
--model_size=7 \
--load=/path/to/megatron/weights/ \
--data_path=/path/to/tokenized/starcoder \
@@ -266,4 +266,5 @@ for sequence in sequences:
- `examples/finetune.sh`
- `examples/verify.sh`
1. Take a look at `weights2megatron/README.md` and `tokenize-utils/README.md` for more information.
1. See the [instruction finetuning](instruction_tuning) guide for more information on how to finetune your pretrained model to follow instructions.
1. Take a look at our [FAQ](faq) section.
1 change: 1 addition & 0 deletions docs/guide/index.md
@@ -3,5 +3,6 @@
```{toctree}
getting_started
instruction_tuning
faq
```
92 changes: 92 additions & 0 deletions docs/guide/instruction_tuning.md
@@ -0,0 +1,92 @@
# Instruction finetuning

This tutorial will guide you through the basics of instruction finetuning using the Megatron-LLM codebase, with LLaMa 2 as the base network.
See also the [getting started](getting_started) guide for information regarding installation of dependencies, pretraining, and weight preparation.
This guide assumes you have followed that tutorial to prepare a 7B model, but feel free to use a different size.
To use Falcon instead, see the comments in the [getting started](getting_started) guide about the differences between the two models.

## Preparing raw data

The dataset used in this guide will be a subset of the [orca](https://huggingface.co/datasets/Open-Orca/OpenOrca) dataset, a general purpose instruction dataset.
We only include the chain-of-thought instructions from the orca dataset in order to keep the dataset size manageable.
Feel free to use any other dataset, as long as the raw data is saved in `.jsonl` format, i.e. one `json` dictionary per line.
The dictionaries must include at least two keys (one for the "instruction" and another one for the expected "answer"), plus an optional "system" key.
In order to retrieve the CoT subset of the orca dataset, use the following code:

```python
import json

from datasets import load_dataset
from tqdm import tqdm

# the `cache_dir` is optional
dataset = load_dataset("Open-Orca/OpenOrca", cache_dir="/path/to/cache", split="train")
with open("/path/to/raw/data.jsonl", "w+") as f:
    for document in tqdm(dataset):
        # keep only the chain-of-thought documents
        if document["id"].startswith("cot."):
            f.write(json.dumps(document) + "\n")
```
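
For reference, a single entry written by the snippet above looks roughly like the following. The values are made up, but the field names match the OpenOrca columns used in the preprocessing step below:

```python
import json

# hypothetical example of one raw document -- adapt the keys to your own dataset
example = {
    "id": "cot.12345",
    "system_prompt": "You are an AI assistant that explains its reasoning step by step.",
    "question": "A train covers 60 km in 40 minutes. What is its speed in km/h?",
    "response": "40 minutes is 2/3 of an hour, so the speed is 60 / (2/3) = 90 km/h.",
}
print(json.dumps(example))  # one such dictionary per line of the `.jsonl` file
```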

## Data preprocessing

In this step we will tokenize the raw data into binary files for optimized data loading during training.
Run:
```bash
python instruct/preprocess_instruct_data.py \
    --input=/path/to/raw/data.jsonl \
    --output_prefix=/path/to/tokenized/orca \
    --tokenizer_type=SentencePieceTokenizer \
    --vocab_file=/path/to/llama/tokenizer.model \
    --chunk_size=32 \
    --workers=32 \
    --vocab_extra_ids_list "<|im_start|>,<|im_end|>" \
    --question_key=question \
    --answer_key=response \
    --system_key=system_prompt  # Optional
```
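
The extra tokens `<|im_start|>` and `<|im_end|>` passed via `--vocab_extra_ids_list` delimit conversation turns in a ChatML-style template. The exact formatting is handled by the preprocessing and data-loading code; the sketch below only illustrates the idea and is not the codebase's literal implementation:

```python
# illustrative sketch of how the special tokens frame one training sample
def build_chatml_sample(question: str, answer: str, system: str = "") -> str:
    parts = []
    if system:
        parts.append(f"<|im_start|>system\n{system}<|im_end|>")
    parts.append(f"<|im_start|>user\n{question}<|im_end|>")
    parts.append(f"<|im_start|>assistant\n{answer}<|im_end|>")
    return "\n".join(parts)

print(build_chatml_sample(
    question="What is 2 + 2?",
    answer="2 + 2 = 4.",
    system="You are a helpful assistant.",
))
```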

## Training

At this point, you should have a Megatron checkpoint ready to be trained (i.e. sharded with the desired parallelism levels).
Take a look at the [getting started](getting_started) guide to see how to convert LLaMa 2 checkpoints from the huggingface format to Megatron and shard the weights.

To start training, use the `finetune.py` script.
Example usage:
```bash
LOG_ARGS="--log_interval 1 --save_interval 100 --eval_interval 50"
TRAIN_ARGS="--train_iters 6500 --lr_decay_style cosine --lr_warmup_iters 650 --lr 2e-5 --min_lr 2e-6"
DISTRIBUTED_ARGS="--nproc_per_node NUMBER_OF_GPUS --nnodes 1 --node_rank 0 --master_addr localhost --master_port 8000"
torchrun $DISTRIBUTED_ARGS finetune.py \
    --tensor_model_parallel_size 4 \
    --pipeline_model_parallel_size 1 \
    --load /path/to/sharded/weights/ \
    --save /path/to/sharded/weights/ \
    --tensorboard_dir /path/to/sharded/weights/tensorboard/ \
    --data_path /path/to/tokenized/orca \
    --model_name llama2 \
    --tokenizer_type SentencePieceTokenizer \
    --vocab_file=/path/to/megatron/weights/tokenizer.model \
    --bf16 \
    --use_flash_attn \
    --micro_batch_size 8 \
    --global_batch_size 64 \
    --sequence_parallel \
    --recompute_granularity selective \
    --use_checkpoint_args \
    --data_type instruction \
    --variable_seq_lengths \
    --vocab_extra_ids_list "<|im_start|>,<|im_end|>" \
    $COMMON_ARGS $LOG_ARGS $TRAIN_ARGS $LLAMA_ARGS
```

The arguments given for pretraining and instruction finetuning are very similar, with the key differences being the batch sizes, learning rates, and the inclusion of `--data_type instruction`, `--variable_seq_lengths` and `--vocab_extra_ids_list`.
With the selected global batch size of 64, the trainer will perform approximately three epochs in 6500 iterations.
This will take approximately three hours on an 8x 80GB A100 node (DP=2, TP=4, PP=1).
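
As a quick sanity check of that epoch estimate (the dataset size below is an assumption, on the order of the CoT subset extracted earlier; count the lines of your `.jsonl` file for the exact figure):

```python
# back-of-the-envelope epoch estimate; the dataset size is an assumption
global_batch_size = 64
train_iters = 6500
approx_dataset_size = 140_000  # roughly the size of the OpenOrca CoT subset

samples_seen = global_batch_size * train_iters   # 416,000 samples
epochs = samples_seen / approx_dataset_size      # about 3 epochs
print(f"{samples_seen} samples -> about {epochs:.1f} epochs")
```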

```{note}
If your `--load` checkpoint was already trained with the Megatron-LLM codebase (rather than, for instance, obtained by directly converting from the huggingface format), you might want to point `--save` to a different directory to avoid overwriting previous checkpoints.
You might also want to include the `--finetune` argument to ignore the previous optimizer and RNG states.
```

## Model Deployment

Once finetuning is complete, you can follow the steps in the [getting started](getting_started) guide to unshard your weights and convert them back to the huggingface format for evaluation and deployment.
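
For instance, here is a minimal generation sketch with the `transformers` library, assuming the converted weights were saved to a hypothetical `/path/to/hf/checkpoint` and that prompts follow the same ChatML-style template used during finetuning:

```python
# minimal inference sketch -- the checkpoint path and prompt format are assumptions,
# adjust them to your converted model and training template
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "/path/to/hf/checkpoint"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
)

prompt = (
    "<|im_start|>user\n"
    "Explain in one sentence why the sky is blue.<|im_end|>\n"
    "<|im_start|>assistant\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
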
141 changes: 111 additions & 30 deletions examples/finetune.sh
@@ -12,9 +12,20 @@ RANK=0
N_NODES=1
ADDR=localhost
WANDB=0
INSTRUCT=0
CHECKPOINT_PATH=none
DATA=none
WANDB_PROJ=none
WANDB_ID=none
WANDB_ENTITY=none
ITERS=1000
SEQ_LEN=none
DATA_PATH=none
TRAINED_PATH=none
HELP_STR="[--rank=$RANK] [--size=$SIZE] [--tp=$TP] [--pp=$PP] [--gpus=$GPUS_PER_NODE] \
[--micro-batch=$MICRO_BATCH] [--global-batch=$GLOBAL_BATCH] [--nodes=$N_NODES] \
[--addr=$ADDR] [--wandb] [--help]"
[--addr=$ADDR] [--wandb] [--instruct] [--checkpoint=...] [--data=...] [--iters=$ITERS] \
[--wandb-project=none] [--wandb-id=none] [--wandb-entity=none] [--seq-len=...] [--out=...] [--help]"


# define help function
@@ -48,67 +59,135 @@ while [[ $# -gt 0 ]]; do
--nodes) N_NODES=$2; shift; shift;;
--addr) ADDR=$2; shift; shift;;
--wandb) WANDB=1; shift;;
--wandb-project) WANDB_PROJ=$2; shift; shift;;
--wandb-id) WANDB_ID=$2; shift; shift;;
--wandb-entity) WANDB_ENTITY=$2; shift; shift;;
--instruct) INSTRUCT=1; shift;;
--checkpoint) CHECKPOINT_PATH=$2; shift; shift;;
--data) DATA_PATH=$2; shift; shift;;
--iters) ITERS=$2; shift; shift;;
--seq-len) SEQ_LEN=$2; shift; shift;;
--out) TRAINED_PATH=$2; shift; shift;;
*) echo unknown argument $1; help; exit 1;;
esac
done


# set args
LR="3e-4"
CHECKPOINT_PATH=/pure-mlo-scratch/alhernan/megatron-data/checkpoints/${MODEL}-${SIZE}b-tp$TP-pp$PP
TENSORBOARD_PATH=$CHECKPOINT_PATH-trained/logging
if [[ $CHECKPOINT_PATH = none ]]; then
CHECKPOINT_PATH=/pure-mlo-scratch/alhernan/megatron-data/checkpoints/${MODEL}-${SIZE}b-tp$TP-pp$PP
fi

if [[ $INSTRUCT = 1 ]]; then
LR="2e-5"
MIN_LR="2e-6"
if [[ $TRAINED_PATH = none ]]; then
TRAINED_PATH=$CHECKPOINT_PATH-instructed
fi
else
LR="3e-4"
MIN_LR="3e-4"
if [[ $TRAINED_PATH = none ]]; then
TRAINED_PATH=$CHECKPOINT_PATH-pretrained
fi
fi

TENSORBOARD_PATH=$TRAINED_PATH/logging
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $N_NODES --node_rank
$RANK --master_addr $ADDR --master_port 6000"

if [[ $MODEL = falcon ]]; then
DATA_PATH=/pure-mlo-scratch/pagliard/data/wikitext-falcon/wiki-train_text_document
if [[ $DATA_PATH = none ]]; then
DATA_PATH=/pure-mlo-scratch/pagliard/data/wikitext-falcon/wiki-train_text_document
fi
TOKENIZER=FalconTokenizer
EXTRA_ARGS="--parallel_attn"
SEQ_LEN=2048
if [[ $SEQ_LEN = none ]]; then
SEQ_LEN=2048
fi
elif [[ $MODEL = llama ]] || [[ $MODEL = llama2 ]] || [[ $MODEL = codellama ]]; then
DATA_PATH=/pure-mlo-scratch/trial-runs/test/pubmed-all-llama_text_document
TOKENIZER=SentencePieceTokenizer
EXTRA_ARGS='--vocab_file=/pure-mlo-scratch/llama/tokenizer.model --use_rms_norm
--glu_activation swiglu --no_tie_embed_logits
--vocab_extra_ids_list "[bib_ref],[/bib_ref],[fig_ref],[/fig_ref],[bib],[/bib],[fig],[/fig],[table],[/table],[formula],[/formula]"'
if [[ $MODEL = codellama ]]; then
EXTRA_ARGS="$EXTRA_ARGS --vocab_file=/pure-mlo-scratch/codellama/CodeLlama-7b/tokenizer.model --rope_theta 1e6"
EXTRA_IDS="[bib_ref],[/bib_ref],[fig_ref],[/fig_ref],[bib],[/bib],[fig],[/fig],[table],[/table],[formula],[/formula]"
EXTRA_ARGS="--use_rms_norm --glu_activation swiglu --no_tie_embed_logits"
if [[ $INSTRUCT = 1 ]]; then
if [[ $DATA_PATH = none ]]; then
DATA_PATH=/pure-mlo-scratch/alhernan/data/orca/orca
fi
EXTRA_IDS="$EXTRA_IDS,<|im_start|>,<|im_end|>"
else
EXTRA_ARGS="$EXTRA_ARGS --vocab_file=/pure-mlo-scratch/llama2/Llama-2-7b-hf/tokenizer.model"
if [[ $DATA_PATH = none ]]; then
DATA_PATH=/pure-mlo-scratch/data/tokenized/pubmed-all/pubmed-all-llama_text_document
fi
fi
TOKENIZER=SentencePieceTokenizer
EXTRA_ARGS="$EXTRA_ARGS --vocab_extra_ids_list $EXTRA_IDS"
if [[ $MODEL == llama ]]; then
SEQ_LEN=2048
if [[ $SEQ_LEN = none ]]; then
SEQ_LEN=2048
fi
EXTRA_ARGS="$EXTRA_ARGS --vocab_file=/pure-mlo-scratch/llama2/Llama-2-7b-hf/tokenizer.model"
EXTRA_ARGS="$EXTRA_ARGS --layernorm_epsilon 1e-6"
elif [[ $MODEL == llama2 ]]; then  # llama 2
SEQ_LEN=4096
elif [[ $MODEL == llama2 ]]; then
if [[ $SEQ_LEN = none ]]; then
SEQ_LEN=4096
fi
EXTRA_ARGS="$EXTRA_ARGS --vocab_file=/pure-mlo-scratch/llama2/Llama-2-7b-hf/tokenizer.model"
EXTRA_ARGS="$EXTRA_ARGS --layernorm_epsilon 1e-5"
else # codellama
SEQ_LEN=16384
if (( $SIZE > 13 )); then # llama 2, 34B and 70B
LR="1.5e-4"
fi
else # codellama
if [[ $SEQ_LEN = none ]]; then
SEQ_LEN=16384
fi
EXTRA_ARGS="$EXTRA_ARGS --vocab_file=/pure-mlo-scratch/codellama/CodeLlama-7b/tokenizer.model --rope_theta 1e6"
fi
if (( $SIZE > 13 )); then # 34B and 70B
LR="1.5e-4"
fi
elif [[ $MODEL = gpt ]]; then
DATA_PATH=/scratch/wikitext-megatron/wikitext-train_text_document
if [[ $DATA_PATH = none ]]; then
DATA_PATH=/scratch/wikitext-megatron/wikitext-train_text_document
fi
TOKENIZER=FalconTokenizer
EXTRA_ARGS="--num_layers 4 --hidden_size 512 --num_attention_heads 8"
SEQ_LEN=2048
if [[ $SEQ_LEN = none ]]; then
SEQ_LEN=2048
fi
else
echo "Model should be either gpt, llama or falcon, not $MODEL"
help
exit 1
fi
COMMON_ARGS="--use_flash_attn --no_bias_gelu_fusion
--seq_length $SEQ_LEN --max_position_embeddings $SEQ_LEN
--log_interval 1 --save_interval 50 --eval_interval 50
--seq_length $SEQ_LEN --max_position_embeddings $SEQ_LEN
--log_interval 1 --save_interval 800 --eval_interval 200
--eval_iters 10 --hidden_dropout 0.0 --position_embedding_type rotary
--no_bias_dropout_fusion --use_checkpoint_args --train_iters 10000
--attention_dropout 0.0 --adam_beta1 0.9 --adam_beta2 0.95 --adam_eps 1e-5
--lr_decay_style cosine --lr_warmup_iters 2000 --lr $LR --min_lr 1e-6
--weight_decay 0.1 --sequence_parallel --recompute_granularity selective"
--no_bias_dropout_fusion --use_checkpoint_args
--attention_dropout 0.0 --adam_beta1 0.9 --adam_beta2 0.95 --adam_eps 1e-5
--lr_decay_style cosine --lr_warmup_fraction 0.1 --lr $LR --min_lr $MIN_LR
--weight_decay 0.1 --sequence_parallel --recompute_granularity selective
--log_timers_to_tensorboard --rope_scaling_factor 1.0"
if [[ $INSTRUCT = 1 ]]; then
COMMON_ARGS="$COMMON_ARGS --variable_seq_lengths --data_type instruction"
if [[ $CHECKPOINT_PATH != $TRAINED_PATH ]]; then
COMMON_ARGS="$COMMON_ARGS --finetune"
fi
fi
if [[ $CHECKPOINT_PATH != $TRAINED_PATH ]]; then
COMMON_ARGS="$COMMON_ARGS --train_iters $ITERS"
fi
if [[ $WANDB = 1 ]]; then
COMMON_ARGS="$COMMON_ARGS --wandb_logger"
if [[ $WANDB_PROJ != none ]]; then
COMMON_ARGS="$COMMON_ARGS --wandb_project $WANDB_PROJ"
fi
if [[ $WANDB_ID != none ]]; then
COMMON_ARGS="$COMMON_ARGS --wandb_id $WANDB_ID"
fi
if [[ $WANDB_ENTITY != none ]]; then
COMMON_ARGS="$COMMON_ARGS --wandb_entity $WANDB_ENTITY"
fi
fi
# print some args
@@ -119,11 +198,13 @@ echo ADDR=$ADDR
echo N_NODES=$N_NODES
echo DATA_PATH=$DATA_PATH
echo CHECKPOINT_PATH=$CHECKPOINT_PATH
echo TRAINED_PATH=$TRAINED_PATH
echo MODEL=$MODEL
echo TP=$TP
echo PP=$PP
echo MICRO_BATCH=$MICRO_BATCH
echo GLOBAL_BATCH=$GLOBAL_BATCH
echo INSTRUCT=$INSTRUCT
echo COMMON_ARGS=$COMMON_ARGS
echo EXTRA_ARGS=$EXTRA_ARGS
echo
Expand All @@ -134,7 +215,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 OMP_NUM_THREADS=16 torchrun $DISTRIBUTED_ARGS fine
--tensor_model_parallel_size $TP \
--pipeline_model_parallel_size $PP \
--load $CHECKPOINT_PATH \
--save $CHECKPOINT_PATH-trained \
--save $TRAINED_PATH \
--tensorboard_dir $TENSORBOARD_PATH \
--data_path $DATA_PATH \
--model_name $MODEL \