diff --git a/megatron/arguments.py b/megatron/arguments.py
index 3a71f5a1c..3835327e0 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -648,6 +648,11 @@ def _add_data_args(parser):
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-weight dataset1-path dataset2-weight '
                        'dataset2-path ...')
+    group.add_argument('--valid-data-path', nargs='*', default=None,
+                       help='Path to the validation dataset. If not provided, '
+                       'data will be selected from --data-path based on --split. '
+                       'Accepted format: dataset1-weight dataset1-path '
+                       'dataset2-weight dataset2-path ...')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py
index e605c216e..32c1f4184 100644
--- a/megatron/data/gpt_dataset.py
+++ b/megatron/data/gpt_dataset.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 import torch
+from collections import OrderedDict
 
 from megatron import mpu, print_rank_0
 from megatron.data.blendable_dataset import BlendableDataset
@@ -30,7 +31,8 @@
 
 def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
-                                    seq_length, seed, skip_warmup):
+                                    seq_length, seed, skip_warmup,
+                                    valid_data_prefix=None):
     """Build train, valid, and test datasets."""
 
     # Single dataset.
@@ -48,27 +50,63 @@
 
     # Build individual datasets.
     train_datasets = []
+    # We'll temporarily store the validation sets and compare them with the arguments in the next step;
+    # this needs to be ordered so it lines up with the weights.
+    ds_shared_with_train = OrderedDict()
     valid_datasets = []
     test_datasets = []
-    for i in range(len(prefixes)):
+    for i, prefix in enumerate(prefixes):
         train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
-            prefixes[i], data_impl, splits_string,
+            prefix, data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             seq_length, seed, skip_warmup)
+        print_rank_0(f"split: {splits_string}")
         if train_ds:
             train_datasets.append(train_ds)
+            print_rank_0(f"train_ds size: {len(train_ds)}")
         if valid_ds:
-            valid_datasets.append(valid_ds)
+            # A safe split with no overlap with training that we may use later if `valid_data_prefix` includes it.
+            ds_shared_with_train[prefix] = valid_ds
+            print_rank_0(f"valid_ds size: {len(valid_ds)}")
         if test_ds:
             test_datasets.append(test_ds)
 
+    if valid_data_prefix is not None:
+        # In this case `valid_data_prefix` defines what goes into the validation mix;
+        # we make sure there is no overlap with the training mix by comparing with ds_shared_with_train.
+        valid_output = get_datasets_weights_and_num_samples(valid_data_prefix,
+                                                            [0, train_valid_test_num_samples[1], 0])
+        valid_prefixes, valid_weights, valid_datasets_samples = valid_output
+        for i, prefix in enumerate(valid_prefixes):
+            if prefix not in ds_shared_with_train:
+                print_rank_0(f"prefix: {prefix} not found in {ds_shared_with_train.keys()}")
+                # No overlap with training, so we can use the full dataset.
+                train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+                    valid_prefixes[i], data_impl, '0,100,0',
+                    valid_datasets_samples[i],
+                    seq_length, seed, skip_warmup)
+                if valid_ds:
+                    valid_datasets.append(valid_ds)
+            else:
+                # Some overlap with training, so we use the safe split that we created earlier.
+                print_rank_0(f"prefix: {prefix} found in {ds_shared_with_train.keys()}")
+                valid_datasets.append(ds_shared_with_train[prefix])
+    else:
+        # In this case we assume the user wants the mix of all the validation splits built beforehand.
+        valid_weights = weights
+        valid_datasets = list(ds_shared_with_train.values())
+
+    print_rank_0(f"valid weights: {valid_weights}")
+    print_rank_0(f"size of validation sets: {[len(dataset) for dataset in valid_datasets]}")
+    print_rank_0(f"size of training sets: {[len(dataset) for dataset in train_datasets]}")
+
     # Blend.
     blending_train_dataset = None
     if train_datasets:
         blending_train_dataset = BlendableDataset(train_datasets, weights)
     blending_valid_dataset = None
     if valid_datasets:
-        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+        blending_valid_dataset = BlendableDataset(valid_datasets, valid_weights)
     blending_test_dataset = None
     if test_datasets:
         blending_test_dataset = BlendableDataset(test_datasets, weights)
@@ -89,6 +127,7 @@
 
     total_num_of_documents = indexed_dataset.sizes.shape[0]
     splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+    print_rank_0(f"splits: {splits}")
 
     # Print stats about the splits.
     print_rank_0(' > dataset split:')
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 0137cad5e..5c82b44c2 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -191,7 +191,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         train_valid_test_num_samples=train_val_test_num_samples,
         seq_length=args.seq_length,
         seed=args.seed,
-        skip_warmup=(not args.mmap_warmup))
+        skip_warmup=(not args.mmap_warmup),
+        valid_data_prefix=args.valid_data_path)
     print_rank_0("> finished creating GPT datasets ...")
 
     return train_ds, valid_ds, test_ds
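For review, here is a simplified, dependency-free sketch of the selection rule the gpt_dataset.py hunk above implements: a prefix requested via --valid-data-path that also appears in the training mix falls back to the held-out split produced by --split, while a prefix outside the training mix is used in full (the '0,100,0' split). Dataset objects are stand-in strings, so this illustrates the control flow only, not the real dataset builders.

```python
from collections import OrderedDict

def pick_validation_sets(train_prefixes, valid_prefixes):
    """Illustrates the branch added in build_train_valid_test_datasets.

    train_prefixes: prefixes in the training mix (their --split validation
        shards are the "safe", non-overlapping candidates).
    valid_prefixes: prefixes requested via --valid-data-path.
    """
    # Stand-ins for the valid_ds objects built alongside each training set.
    ds_shared_with_train = OrderedDict(
        (prefix, f"{prefix}[split-valid]") for prefix in train_prefixes)

    valid_datasets = []
    for prefix in valid_prefixes:
        if prefix not in ds_shared_with_train:
            # No overlap with training: use the whole dataset ('0,100,0').
            valid_datasets.append(f"{prefix}[full]")
        else:
            # Overlaps with training: reuse the safe held-out split.
            valid_datasets.append(ds_shared_with_train[prefix])
    return valid_datasets

print(pick_validation_sets(["ds-0", "ds-1"], ["ds-0", "ds-4"]))
# ['ds-0[split-valid]', 'ds-4[full]']
```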
diff --git a/scripts/test_multiple_dataset_sampling/test_sampling.sh b/scripts/test_multiple_dataset_sampling/test_sampling.sh
index 8c39ba1e4..a2a964829 100644
--- a/scripts/test_multiple_dataset_sampling/test_sampling.sh
+++ b/scripts/test_multiple_dataset_sampling/test_sampling.sh
@@ -85,17 +85,16 @@ ZERO_STAGE=0
 #GLOBAL_BATCH=128
 #WORKER_STR="-i worker-0"
 
-
-# 52B
-TP=4
-PP=16
-HIDDEN=1024
-LAYERS=24
+# super small model
+TP=1
+PP=1
+HIDDEN=256
+LAYERS=2
 SEQ=128
-GLOBAL_BATCH=16
+GLOBAL_BATCH=4
 WORKER_STR=""
 
-MICRO_BATCH=8
+MICRO_BATCH=4
 
 while [[ $# -gt 0 ]]
 do
diff --git a/scripts/test_multiple_dataset_sampling/test_valid_sampling.sh b/scripts/test_multiple_dataset_sampling/test_valid_sampling.sh
new file mode 100644
index 000000000..0816d55e4
--- /dev/null
+++ b/scripts/test_multiple_dataset_sampling/test_valid_sampling.sh
@@ -0,0 +1,184 @@
+EXP_PATH="./dumped/test/"
+mkdir -p $EXP_PATH
+BASE_DATA_PATH=$EXP_PATH
+INPUT_PATH=$EXP_PATH
+OUTPUT_PATH=$EXP_PATH
+
+wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -P ${BASE_DATA_PATH}
+wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -P ${BASE_DATA_PATH}
+
+python scripts/test_multiple_dataset_sampling/create_dummy_dataset.py --dir ${INPUT_PATH}
+
+
+python tools/preprocess_data.py \
+       --input ${INPUT_PATH}/dataset_0.json \
+       --output-prefix ${OUTPUT_PATH}/dataset-0 \
+       --vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
+       --dataset-impl mmap \
+       --tokenizer-type GPT2BPETokenizer \
+       --merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
+       --append-eod
+
+python tools/preprocess_data.py \
+       --input ${INPUT_PATH}/dataset_1.json \
+       --output-prefix ${OUTPUT_PATH}/dataset-1 \
+       --vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
+       --dataset-impl mmap \
+       --tokenizer-type GPT2BPETokenizer \
+       --merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
+       --append-eod
+
+python tools/preprocess_data.py \
+       --input ${INPUT_PATH}/dataset_2.json \
+       --output-prefix ${OUTPUT_PATH}/dataset-2 \
+       --vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
+       --dataset-impl mmap \
+       --tokenizer-type GPT2BPETokenizer \
+       --merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
+       --append-eod
+
+python tools/preprocess_data.py \
+       --input ${INPUT_PATH}/dataset_3.json \
+       --output-prefix ${OUTPUT_PATH}/dataset-3 \
+       --vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
+       --dataset-impl mmap \
+       --tokenizer-type GPT2BPETokenizer \
+       --merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
+       --append-eod
+
+python tools/preprocess_data.py \
+       --input ${INPUT_PATH}/dataset_4.json \
+       --output-prefix ${OUTPUT_PATH}/dataset-4 \
+       --vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
+       --dataset-impl mmap \
+       --tokenizer-type GPT2BPETokenizer \
+       --merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
+       --append-eod
+
+
+DIR=`pwd`
+DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
+mkdir -p ${BASE_DATA_PATH}/logs
+
+DATASET_0="${OUTPUT_PATH}/dataset-0_text_document"
+DATASET_1="${OUTPUT_PATH}/dataset-1_text_document"
+DATASET_2="${OUTPUT_PATH}/dataset-2_text_document"
+DATASET_3="${OUTPUT_PATH}/dataset-3_text_document"
+DATASET_4="${OUTPUT_PATH}/dataset-4_text_document"
+DATASET="0.1 ${DATASET_0} 0.25 ${DATASET_1} 0.2 ${DATASET_2} 0.15 ${DATASET_3} 0.3 ${DATASET_4}"
+VALID_DATASET="1.0 ${DATASET_0}"
+VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
+MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt
+
+CONFIG_JSON="${EXP_PATH}/ds_config.json"
+touch $CONFIG_JSON
+
+USE_DEEPSPEED=1
+ZERO_STAGE=0
+
+# super small model
+TP=1
+PP=1
+HIDDEN=256
+LAYERS=2
+SEQ=128
+GLOBAL_BATCH=4
+WORKER_STR=""
+
+MICRO_BATCH=4
+
+while [[ $# -gt 0 ]]
+do
+key="$1"
+case $key in
+    --no-deepspeed)
+    USE_DEEPSPEED=0;
+    shift
+    ;;
+    -z|--zero-stage)
+    ZERO_STAGE=$2;
+    shift
+    ;;
+    *)
+    echo "Unknown argument(s)"
+    usage
+    exit 1
+    shift
+    ;;
+esac
+done
+
+options=" \
+        --tensor-model-parallel-size $TP \
+        --pipeline-model-parallel-size $PP \
+        --num-layers $LAYERS \
+        --hidden-size $HIDDEN \
+        --num-attention-heads 32 \
+        --seq-length $SEQ \
+        --loss-scale 12 \
+        --max-position-embeddings $SEQ \
+        --micro-batch-size $MICRO_BATCH \
+        --global-batch-size $GLOBAL_BATCH \
+        --train-iters 1000 \
+        --lr 6.0e-5 \
+        --min-lr 6.0e-6 \
+        --lr-decay-style cosine \
+        --log-interval 1 \
+        --eval-iters 100 \
+        --eval-interval 40 \
+        --data-path ${DATASET} \
+        --valid-data-path ${VALID_DATASET} \
+        --vocab-file ${VOCAB_PATH} \
+        --merge-file ${MERGE_PATH} \
+        --save-interval 1000 \
+        --split 98,2,0 \
+        --clip-grad 1.0 \
+        --weight-decay 0.1 \
+        --adam-beta1 0.9 \
+        --adam-beta2 0.95 \
+        --init-method-std 0.006 \
+        --fp16 \
+        --checkpoint-activations
+        "
+
+
+if [[ ${USE_DEEPSPEED} -eq 1 ]]; then
+        echo "Using DeepSpeed"
+        options="${options} \
+                --deepspeed \
+                --deepspeed_config=${CONFIG_JSON} \
+                --zero-stage=${ZERO_STAGE} \
+                --deepspeed-activation-checkpointing \
+        "
+fi
+
+
+cat <<EOT > $CONFIG_JSON
+{
+  "train_batch_size" : $GLOBAL_BATCH,
+  "train_micro_batch_size_per_gpu": $MICRO_BATCH,
+  "steps_per_print": 1,
+  "zero_optimization": {
+    "stage": $ZERO_STAGE
+  },
+  "gradient_clipping": 1.0,
+  "prescale_gradients": true,
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 500,
+    "hysteresis": 2,
+    "min_loss_scale": 1,
+    "initial_scale_power": 12
+  },
+  "wall_clock_breakdown" : true
+}
+EOT
+
+# run_cmd="deepspeed $WORKER_STR ${DIR}/test_sampling.py $@ ${options}"
+run_cmd="deepspeed $WORKER_STR pretrain_gpt.py $@ ${options}"
+
+echo ${run_cmd}
+eval ${run_cmd}
+
+set +x
\ No newline at end of file