5 changes: 5 additions & 0 deletions megatron/arguments.py
@@ -648,6 +648,11 @@ def _add_data_args(parser):
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--valid-data-path', nargs='*', default=None,
help='Path to the validation dataset. If not provided, '
'data will be selected from --data-path based on --split. '
'Accepted format: dataset1-weight dataset1-path '
'dataset2-weight dataset2-path ...')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
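As a quick illustration of the accepted format, here is a minimal sketch of how the new flag combines with the existing ones on the command line; the dataset paths and weights are placeholders, not values from this change:

# Sketch only; weights and paths below are placeholders.
# The training mix still comes from --data-path and is split according to --split.
# When --valid-data-path is given, the validation mix is built from those paths instead,
# reusing the --split validation portion for any path that also appears in --data-path.
DATASET="0.7 /path/to/dataset-a_text_document 0.3 /path/to/dataset-b_text_document"
VALID_DATASET="1.0 /path/to/dataset-a_text_document"

       --data-path ${DATASET} \
       --valid-data-path ${VALID_DATASET} \
       --split 98,2,0 \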
49 changes: 44 additions & 5 deletions megatron/data/gpt_dataset.py
@@ -20,6 +20,7 @@

import numpy as np
import torch
from collections import OrderedDict

from megatron import mpu, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
@@ -30,7 +31,8 @@

def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup):
seq_length, seed, skip_warmup,
valid_data_prefix=None):
"""Build train, valid, and test datasets."""

# Single dataset.
@@ -48,27 +50,63 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,

# Build individual datasets.
train_datasets = []
# temporarily store the validation sets so we can compare them with `valid_data_prefix` in the next step
# this needs to stay ordered so it lines up with the weights
ds_shared_with_train = OrderedDict()
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
for i, prefix in enumerate(prefixes):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
prefix, data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
seq_length, seed, skip_warmup)
print_rank_0(f"split: {splits_string}")
if train_ds:
train_datasets.append(train_ds)
print_rank_0(f"train_ds size: {len(train_ds)}")
if valid_ds:
valid_datasets.append(valid_ds)
# a safe split with no overlap with training, which we may reuse later if `valid_data_prefix` includes this prefix
ds_shared_with_train[prefix] = valid_ds
print_rank_0(f"valid_ds size: {len(valid_ds)}")
if test_ds:
test_datasets.append(test_ds)

if valid_data_prefix is not None:
# in this case `valid_data_prefix` defines what is in the validation mix
# we make sure there is no overlap with the training mix by comparing with ds_shared_with_train
valid_output = get_datasets_weights_and_num_samples(valid_data_prefix,
[0, train_valid_test_num_samples[1], 0])
valid_prefixes, valid_weights, valid_datasets_samples = valid_output
for i, prefix in enumerate(valid_prefixes):
if prefix not in ds_shared_with_train:
print_rank_0(f"prefix: {prefix} not found in {ds_shared_with_train.keys()}")
Contributor:
This looks more like a debug print and won't be of any added value to the user, IMHO, and just adds to the noise. What do you think?

Collaborator (author):
Agreed, I'll remove some of the prints.
# no overlap with the training mix, so we can use this dataset in full
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
valid_prefixes[i], data_impl, '0,100,0',
Contributor:
Should the presence of this new flag require the user to pass the normal data-path split so that the 2nd split is 0, e.g. 100,0,0 or 90,0,10? Will this make things simpler?

Contributor (@stas00, Oct 18, 2021):
Or I'd go even simpler and require no --split argument at all: if a separate validation data path is passed, then do no splitting, i.e. data-path => Train, valid-data-path => Valid (and we can add test-data-path if needed).

Yet another possible solution: leave the --data-path feature as is, add a --(train|valid|test)-data-path arg, and make these mutually exclusive. So it's one of the two sets:

  1. --data-path and --split pair
  2. --(train|valid|test)-data-path (3 args) and no --split, with test optional.

The originally proposed solution feels too much like a "patch".

Collaborator (author):
Something I feel we need with this feature is a way to deal with datasets that only have a train split and no validation, including OSCAR, which is why I wrote it like this. In that case you still want to be able to act on your validation mix even though you don't have a separate file to point to: for example, you may want to use only one of the validation splits created in the preceding step rather than the mix of all of them.

Collaborator (author, @TevenLeScao, Oct 18, 2021):
If you point to a file that's not already in the training set, e.g. an external validation set, then it feels natural to use all of it. We can easily add another flag to split it, but I cannot see a use case for it. (edit: now I can see a use case for it, keeping the validation split from getting too big, although you could probably fix that with the valid_weights)

Contributor:
If I understand you correctly, that's why I suggested having two different sets of APIs:

  1. --data-path and --split pair, for normal use
  2. --(train|valid|test)-data-path (3 args), no --split, and test optional, for when you want to take full control over the splits.

I think this would make things easier to use, but perhaps I'm not seeing some nuance here.

Collaborator (author):
It wasn't bearing, the code needed to be clearer anyway. The weights of valid-data-path apply, which is clear since we're also using the data from valid-data-path!

I am not sure what you would call a wrong use: in the OSCAR multilingual experiments, we 1. do not have a separate validation set and 2. need to restrict the validation to only one language or set of languages. This naturally leads to passing a path that overrides the original validation mix but also has overlap with it.
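To make this OSCAR use case concrete, a hypothetical invocation could look like the sketch below; the OSCAR_* variables are placeholders for preprocessed per-language shards, not paths from this PR:

# Hypothetical sketch of the multilingual case described above: train on an
# English/French mix but validate only on French. Because the French prefix also
# appears in --data-path, the new code reuses its --split validation portion,
# so validation documents stay out of training.
OSCAR_EN=/path/to/oscar-en_text_document   # placeholder
OSCAR_FR=/path/to/oscar-fr_text_document   # placeholder

       --data-path 0.7 ${OSCAR_EN} 0.3 ${OSCAR_FR} \
       --valid-data-path 1.0 ${OSCAR_FR} \
       --split 98,2,0 \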

Collaborator (@sbmaruf, Oct 19, 2021):
I agree with @stas00.

Why are you selecting validation data from the training data? You can just create data in --valid-data-path and take it from there. To me it seems redundant. The implementation was: you only take validation data from training when you don't pass any data path in the --valid-data-path argument. Taking data from both --data-path and --valid-data-path is confusing to me.

I understand this is needed for OSCAR, but I am against it because this way we cannot track which samples are selected for training and which ones for development. IMO if OSCAR doesn't have a development set, we should create our own split and share that split for reproducibility. This is also important because we are comparing across multiple datasets.

Another problem. Let's assume this case:

TRAIN_DATASET=DATASET_0
VALID_DATASET="0.1 ${DATASET_0} 0.25 ${DATASET_1} 0.2 ${DATASET_2} 0.15 ${DATASET_3} 0.3 ${DATASET_4}"

When you run this with --split 80,20,0, then in addition to the validation data, 20% of the training data will be sampled with a relative probability of 0.1. For a normal user, this seems overly complicated.
@TevenLeScao

Collaborator (author):
Yes, this is different from your implementation: as discussed above, this only takes data from --valid-data-path and not from --data-path. We can talk about the need for a separate development set for OSCAR, but it is not the only dataset in this case, as language modeling generally doesn't need validation sets. I'd rather give the user (i.e. us!) this flexibility than force them to make their own splits.

Collaborator (author):
Also, the case you've shown does exactly what the user wants: it samples dataset_0 exactly as much as they expect, while ensuring it picks the parts that don't overlap with training. If the user feels that is too complicated, they can also just split before preprocessing and avoid overlapping data-paths and valid-data-paths, so there's no harm.

Contributor (@stas00, Oct 19, 2021):
I'm glad to hear that it's not just me who finds the proposed API confusing; the fact that we are still trying to figure it out is an indication of that.

May I re-iterate the proposal I suggested yesterday, where we have two mutually exclusive modes:

  1. --data-path and --split pair, for normal use
  2. --(train|valid|test)-data-path (3 args), no --split, and test optional, for when you want to take full control over the splits.

I think this would make things easier to use, since the 2nd approach is very explicit.
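For reference, a sketch of what the two mutually exclusive modes proposed here might look like; --train-data-path and --test-data-path are part of the proposal only and are not implemented in this PR, and the dataset variables are placeholders:

# Mode 1 (current behaviour): a single mix plus --split.
       --data-path 0.7 ${DATASET_A} 0.3 ${DATASET_B} --split 98,2,0

# Mode 2 (proposed, not implemented here): explicit per-split paths, no --split,
# with the test path optional.
       --train-data-path 1.0 ${DATASET_A} \
       --valid-data-path 1.0 ${DATASET_B} \
       --test-data-path 1.0 ${DATASET_C}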

valid_datasets_samples[i],
seq_length, seed, skip_warmup)
if valid_ds:
valid_datasets.append(valid_ds)
else:
# this prefix has some overlap with training, so we use the safe split created earlier
print_rank_0(f"prefix: {prefix} found in {ds_shared_with_train.keys()}")
valid_datasets.append(ds_shared_with_train[prefix])
else:
# in this case we assume that the user wants to use the mix of all the validation splits built beforehand.
valid_weights = weights
valid_datasets = ds_shared_with_train.values()

print_rank_0(f"valid weights: {valid_weights}")
print_rank_0(f"size of validation sets: {[len(dataset) for dataset in valid_datasets]}")
print_rank_0(f"size of training sets: {[len(dataset) for dataset in train_datasets]}")

# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, valid_weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
@@ -89,6 +127,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,

total_num_of_documents = indexed_dataset.sizes.shape[0]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
print_rank_0(f"splits: {splits}")

# Print stats about the splits.
print_rank_0(' > dataset split:')
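Putting the new logic together, a hedged sketch of the two situations the code above distinguishes; all dataset paths are placeholders:

# Sketch only; paths are placeholders.
# Case 1: a --valid-data-path prefix that also appears in --data-path
#         -> its non-overlapping --split validation portion is reused.
# Case 2: a --valid-data-path prefix that does not appear in --data-path
#         -> it is built with a '0,100,0' split, i.e. used entirely for validation.
       --data-path 0.8 ${SHARED_DS} 0.2 ${OTHER_DS} \
       --valid-data-path 0.5 ${SHARED_DS} 0.5 ${EXTERNAL_VALID_DS} \
       --split 98,2,0 \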
3 changes: 2 additions & 1 deletion pretrain_gpt.py
@@ -191,7 +191,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
skip_warmup=(not args.mmap_warmup),
valid_data_prefix=args.valid_data_path)
print_rank_0("> finished creating GPT datasets ...")

return train_ds, valid_ds, test_ds
15 changes: 7 additions & 8 deletions scripts/test_multiple_dataset_sampling/test_sampling.sh
@@ -85,17 +85,16 @@ ZERO_STAGE=0
#GLOBAL_BATCH=128
#WORKER_STR="-i worker-0"


# 52B
TP=4
PP=16
HIDDEN=1024
LAYERS=24
#super small model
TP=1
PP=1
HIDDEN=256
LAYERS=2
SEQ=128
GLOBAL_BATCH=16
GLOBAL_BATCH=4
WORKER_STR=""

MICRO_BATCH=8
MICRO_BATCH=4

while [[ $# -gt 0 ]]
do
184 changes: 184 additions & 0 deletions scripts/test_multiple_dataset_sampling/test_valid_sampling.sh
@@ -0,0 +1,184 @@
EXP_PATH="./dumped/test/"
mkdir -p $EXP_PATH
BASE_DATA_PATH=$EXP_PATH
INPUT_PATH=$EXP_PATH
OUTPUT_PATH=$EXP_PATH

wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json -P ${BASE_DATA_PATH}
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -P ${BASE_DATA_PATH}

python scripts/test_multiple_dataset_sampling/create_dummy_dataset.py --dir ${INPUT_PATH}


python tools/preprocess_data.py \
--input ${INPUT_PATH}/dataset_0.json \
--output-prefix ${OUTPUT_PATH}/dataset-0 \
--vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
--dataset-impl mmap \
--tokenizer-type GPT2BPETokenizer \
--merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
--append-eod

python tools/preprocess_data.py \
--input ${INPUT_PATH}/dataset_1.json \
--output-prefix ${OUTPUT_PATH}/dataset-1 \
--vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
--dataset-impl mmap \
--tokenizer-type GPT2BPETokenizer \
--merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
--append-eod

python tools/preprocess_data.py \
--input ${INPUT_PATH}/dataset_2.json \
--output-prefix ${OUTPUT_PATH}/dataset-2 \
--vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
--dataset-impl mmap \
--tokenizer-type GPT2BPETokenizer \
--merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
--append-eod

python tools/preprocess_data.py \
--input ${INPUT_PATH}/dataset_3.json \
--output-prefix ${OUTPUT_PATH}/dataset-3 \
--vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
--dataset-impl mmap \
--tokenizer-type GPT2BPETokenizer \
--merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
--append-eod

python tools/preprocess_data.py \
--input ${INPUT_PATH}/dataset_4.json \
--output-prefix ${OUTPUT_PATH}/dataset-4 \
--vocab ${BASE_DATA_PATH}/gpt2-vocab.json \
--dataset-impl mmap \
--tokenizer-type GPT2BPETokenizer \
--merge-file ${BASE_DATA_PATH}/gpt2-merges.txt \
--append-eod


DIR=`pwd`
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
mkdir -p ${BASE_DATA_PATH}/logs

DATASET_0="${OUTPUT_PATH}/dataset-0_text_document"
DATASET_1="${OUTPUT_PATH}/dataset-1_text_document"
DATASET_2="${OUTPUT_PATH}/dataset-2_text_document"
DATASET_3="${OUTPUT_PATH}/dataset-3_text_document"
DATASET_4="${OUTPUT_PATH}/dataset-4_text_document"
DATASET="0.1 ${DATASET_0} 0.25 ${DATASET_1} 0.2 ${DATASET_2} 0.15 ${DATASET_3} 0.3 ${DATASET_4}"
VALID_DATASET="1.0 ${DATASET_0}"
VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt

CONFIG_JSON="${EXP_PATH}/ds_config.json"
touch $CONFIG_JSON

USE_DEEPSPEED=1
ZERO_STAGE=0

#super small model
TP=1
PP=1
HIDDEN=256
LAYERS=2
SEQ=128
GLOBAL_BATCH=4
WORKER_STR=""

MICRO_BATCH=4

while [[ $# -gt 0 ]]
do
key="$1"
case $key in
--no-deepspeed)
USE_DEEPSPEED=0;
shift
;;
-z|--zero-stage)
ZERO_STAGE=$2;
shift
;;
*)
echo "Unknown argument(s)"
usage
exit 1
shift
;;
esac
done

options=" \
--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--num-layers $LAYERS \
--hidden-size $HIDDEN \
--num-attention-heads 32 \
--seq-length $SEQ \
--loss-scale 12 \
--max-position-embeddings $SEQ \
--micro-batch-size $MICRO_BATCH \
--global-batch-size $GLOBAL_BATCH \
--train-iters 1000 \
--lr 6.0e-5 \
--min-lr 6.0e-6 \
--lr-decay-style cosine \
--log-interval 1 \
--eval-iters 100 \
--eval-interval 40 \
--data-path ${DATASET} \
--valid-data-path ${VALID_DATASET} \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--save-interval 1000 \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.006 \
--fp16 \
--checkpoint-activations
"


if [[ ${USE_DEEPSPEED} -eq 1 ]]; then
echo "Using DeepSpeed"
options="${options} \
--deepspeed \
--deepspeed_config=${CONFIG_JSON} \
--zero-stage=${ZERO_STAGE} \
--deepspeed-activation-checkpointing \
"
fi


cat <<EOT > $CONFIG_JSON
{
"train_batch_size" : $GLOBAL_BATCH,
"train_micro_batch_size_per_gpu": $MICRO_BATCH,
"steps_per_print": 1,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"gradient_clipping": 1.0,
"prescale_gradients": true,
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 12
},
"wall_clock_breakdown" : true
}
EOT

# run_cmd="deepspeed $WORKER_STR ${DIR}/test_sampling.py $@ ${options}"
run_cmd="deepspeed $WORKER_STR pretrain_gpt.py $@ ${options}"

echo ${run_cmd}
eval ${run_cmd}

set +x
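Assuming the repository root as the working directory, the new test script can presumably be launched as below; it downloads the GPT-2 vocab and merges files, builds the dummy datasets, and then starts a tiny pretraining run with --valid-data-path set:

# Run from the repository root.
bash scripts/test_multiple_dataset_sampling/test_valid_sampling.sh

# or, to skip passing the DeepSpeed flags to pretrain_gpt.py:
bash scripts/test_multiple_dataset_sampling/test_valid_sampling.sh --no-deepspeed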