diff --git a/README.md b/README.md
index 127091f83..4b3e08ce2 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,11 @@ python tools/preprocess_data.py \
 The output will be two files named, in this case, `my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`. The `--data-path` specified in later BERT training is the full path and new filename, but without the file extension.

+For T5 use the same preprocessing as BERT, perhaps renaming it to:
+<pre>
+       --output-prefix my-t5 \
+</pre>
+
 Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type:

 python tools/preprocess_data.py \
@@ -247,13 +252,14 @@ T5_ARGS="--num-layers 24 \
--micro-batch-size 16 \
--global-batch-size 2048 \
--vocab-file $VOCAB_FILE \
+ --vocab-extra-ids 100 \
--split 949,50,1 \
--fp16"
OUTPUT_ARGS=<same as those in BERT pretraining above>
python pretrain_t5.py \
- $BERT_ARGS \
+ $T5_ARGS \
$OUTPUT_ARGS \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
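As a rough sketch of the T5 preprocessing invocation that the README change above implies (the input and vocab paths are placeholders, and the flags mirror the BERT-style preprocessing command shown in the ORQA README further down in this diff):

python tools/preprocess_data.py \
       --input my-corpus.json \
       --vocab-file bert-vocab.txt \
       --tokenizer-type BertWordPieceLowerCase \
       --split-sentences \
       --output-prefix my-t5 \
       --workers 5

The output should then be a `my-t5_text_sentence.bin`/`.idx` pair, passed to training via `--data-path`, with the sentinel tokens supplied on the training side through the new `--vocab-extra-ids 100` argument shown above.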
diff --git a/examples/create_embeddings.sh b/examples/create_embeddings.sh
deleted file mode 100644
index 59a5839f7..000000000
--- a/examples/create_embeddings.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/bin/bash
-
-# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)
-
-RANK=0
-WORLD_SIZE=1
-
-# Wikipedia data can be downloaded from the following link:
-# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
-EVIDENCE_DATA_DIR=
-EMBEDDING_PATH=
-CHECKPOINT_PATH=
-
-python tools/create_doc_index.py \
- --num-layers 12 \
- --hidden-size 768 \
- --num-attention-heads 12 \
- --tensor-model-parallel-size 1 \
- --micro-batch-size 128 \
- --checkpoint-activations \
- --seq-length 512 \
- --retriever-seq-length 256 \
- --max-position-embeddings 512 \
- --load ${CHECKPOINT_PATH} \
- --evidence-data-path ${EVIDENCE_DATA_DIR} \
- --embedding-path ${EMBEDDING_PATH} \
- --indexer-log-interval 1000 \
- --indexer-batch-size 128 \
- --vocab-file bert-vocab.txt \
- --num-workers 2 \
- --fp16
-
diff --git a/examples/evaluate_ict_zeroshot_nq.sh b/examples/evaluate_retriever_nq.sh
similarity index 75%
rename from examples/evaluate_ict_zeroshot_nq.sh
rename to examples/evaluate_retriever_nq.sh
index e1ce45a93..8b87be302 100644
--- a/examples/evaluate_ict_zeroshot_nq.sh
+++ b/examples/evaluate_retriever_nq.sh
@@ -1,19 +1,19 @@
#!/bin/bash
# Evaluate natural question test data given Wikipedia embeddings and pretrained
-# ICT model
+# ICT model or a model finetuned on the Natural Questions task
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=
EMBEDDING_PATH=
-CHECKPOINT_PATH=
+CHECKPOINT_PATH=
-QA_FILE=
+QA_FILE=
python tasks/main.py \
- --task ICT-ZEROSHOT-NQ \
+ --task RETRIEVER-EVAL \
--tokenizer-type BertWordPieceLowerCase \
--num-layers 12 \
--hidden-size 768 \
@@ -29,8 +29,10 @@ python tasks/main.py \
--retriever-seq-length 256 \
--vocab-file bert-vocab.txt\
--qa-data-test ${QA_FILE} \
- --num-workers 2 \
--faiss-use-gpu \
--retriever-report-topk-accuracies 1 5 20 100 \
- --fp16
+ --fp16 \
+ --indexer-log-interval 1000 \
+ --indexer-batch-size 128
+
diff --git a/examples/finetune_retriever_distributed.sh b/examples/finetune_retriever_distributed.sh
new file mode 100755
index 000000000..535a2e053
--- /dev/null
+++ b/examples/finetune_retriever_distributed.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# Finetune a BERT or pretrained ICT model using Google's Natural Questions data
+# Datasets can be downloaded from the following link:
+# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
+
+WORLD_SIZE=8
+
+DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
+ --nnodes 1 \
+ --node_rank 0 \
+ --master_addr localhost \
+ --master_port 6000"
+
+CHECKPOINT_PATH=
+
+# Load either of the below
+BERT_LOAD_PATH=
+PRETRAINED_CHECKPOINT=
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
+ --task RET-FINETUNE-NQ \
+ --train-with-neg \
+ --train-hard-neg 1 \
+ --pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \
+ --num-layers 12 \
+ --hidden-size 768 \
+ --num-attention-heads 12 \
+ --tensor-model-parallel-size 1 \
+ --tokenizer-type BertWordPieceLowerCase \
+ --train-data nq-train.json \
+ --valid-data nq-dev.json \
+ --save ${CHECKPOINT_PATH} \
+ --load ${CHECKPOINT_PATH} \
+ --vocab-file bert-vocab.txt \
+ --bert-load ${BERT_LOAD_PATH} \
+ --save-interval 5000 \
+ --log-interval 10 \
+ --eval-interval 20000 \
+ --eval-iters 100 \
+ --indexer-log-interval 1000 \
+ --faiss-use-gpu \
+ --DDP-impl torch \
+ --fp16 \
+ --retriever-report-topk-accuracies 1 5 10 20 100 \
+ --seq-length 512 \
+ --retriever-seq-length 256 \
+ --max-position-embeddings 512 \
+ --retriever-score-scaling \
+ --epochs 80 \
+ --micro-batch-size 8 \
+ --eval-micro-batch-size 16 \
+ --indexer-batch-size 128 \
+ --lr 2e-5 \
+ --lr-warmup-fraction 0.01 \
+ --weight-decay 1e-1
diff --git a/examples/pretrain_t5.sh b/examples/pretrain_t5.sh
index 71fea8489..91fd5929b 100644
--- a/examples/pretrain_t5.sh
+++ b/examples/pretrain_t5.sh
@@ -15,7 +15,7 @@ python pretrain_t5.py \
--encoder-seq-length 512 \
--decoder-seq-length 128 \
--micro-batch-size 16 \
- --global-batch-size 2048 \
+ --global-batch-size 16 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--lr-decay-iters 1000000 \
@@ -35,4 +35,5 @@ python pretrain_t5.py \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
- --fp16
+ --fp16 \
+ --vocab-extra-ids 100
diff --git a/examples/pretrain_t5_distributed.sh b/examples/pretrain_t5_distributed.sh
index 778b4ad2a..2beb1cdac 100644
--- a/examples/pretrain_t5_distributed.sh
+++ b/examples/pretrain_t5_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--encoder-seq-length 512 \
--decoder-seq-length 128 \
--micro-batch-size 16 \
- --global-batch-size 2048 \
+ --global-batch-size 128 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--lr-decay-iters 1000000 \
@@ -44,4 +44,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
- --fp16
+ --fp16 \
+ --vocab-extra-ids 100
diff --git a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh
index 9be70393d..23f1cd664 100644
--- a/examples/pretrain_t5_distributed_with_mp.sh
+++ b/examples/pretrain_t5_distributed_with_mp.sh
@@ -24,8 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--encoder-seq-length 512 \
--decoder-seq-length 128 \
--micro-batch-size 16 \
- --global-batch-size 2048 \
- --seq-length 512 \
+ --global-batch-size 128 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--lr-decay-iters 1000000 \
@@ -45,4 +44,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
- --fp16
+ --fp16 \
+ --vocab-extra-ids 100
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 0b21ffbdb..49f8b253a 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -440,6 +440,8 @@ def _add_training_args(parser):
help='Run optimizer on CPU')
group.add_argument('--cpu_torch_adam', action='store_true',
help='Use Torch Adam as optimizer on CPU.')
+ group.add_argument('--codecarbon-dir', type=str, default=None,
+ help='Write CodeCarbon logs to this directory.')
return parser
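The new `--codecarbon-dir` flag is consumed in megatron/training.py further down in this diff, where each rank starts a codecarbon.EmissionsTracker and writes an `emissions-<rank>.csv` file into that directory; requirements.txt pins `codecarbon==1.2.0`. A minimal sketch of enabling it, assuming the T5 pretraining variables from the README excerpt above ($T5_ARGS, $OUTPUT_ARGS, $CHECKPOINT_PATH, $DATA_PATH are placeholders from that example):

# codecarbon==1.2.0 is the version pinned in requirements.txt
pip install codecarbon==1.2.0

# append the new flag to any pretrain_* launch, e.g. the T5 example;
# each rank then writes emissions-<rank>.csv into codecarbon-logs/
python pretrain_t5.py \
       $T5_ARGS \
       $OUTPUT_ARGS \
       --save $CHECKPOINT_PATH \
       --load $CHECKPOINT_PATH \
       --data-path $DATA_PATH \
       --codecarbon-dir codecarbon-logs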
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 3cc6a8e2e..49c4b6533 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -60,8 +60,8 @@ def _compare(arg_name, old_arg_name=None):
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
- _compare('max_position_embeddings')
if args.vocab_file:
+ _compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
diff --git a/megatron/indexer.py b/megatron/indexer.py
index c0d1ca7de..d2ff9e36f 100644
--- a/megatron/indexer.py
+++ b/megatron/indexer.py
@@ -1,15 +1,16 @@
import sys
+import time
import torch
import torch.distributed as dist
-from megatron import get_args
+from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
-from megatron.model.biencoder_model import biencoder_model_provider
+from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
@@ -29,7 +30,6 @@ def __init__(self):
# need to know whether we're using a REALM checkpoint (args.load)
# or ICT checkpoint
assert not (args.load and args.ict_load)
- #self.using_realm_chkpt = args.ict_load is None
self.log_interval = args.indexer_log_interval
self.batch_size = args.indexer_batch_size
@@ -47,8 +47,8 @@ def load_attributes(self):
if self.biencoder_shared_query_context_model:
only_context_model = False
- model = get_model(lambda: biencoder_model_provider(only_context_model \
- = only_context_model, biencoder_shared_query_context_model = \
+ model = get_model(get_model_provider(only_context_model=\
+ only_context_model, biencoder_shared_query_context_model=\
self.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
@@ -85,6 +85,7 @@ def build_and_save_index(self):
"""
assert len(self.model) == 1
unwrapped_model = self.model[0]
+
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
@@ -103,6 +104,7 @@ def build_and_save_index(self):
context_logits = unwrapped_model.embed_text(
unwrapped_model.context_model, context_tokens, context_mask,
context_types)
+
context_logits = detach(context_logits)
row_id = detach(row_id)
diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py
index d200bdb17..c53af8d54 100644
--- a/megatron/learning_rates.py
+++ b/megatron/learning_rates.py
@@ -87,7 +87,7 @@ def get_lr(self):
else:
raise Exception('{} decay style is not supported.'.format(
self.decay_style))
-
+
return self.min_lr + coeff * delta_lr
diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py
index 51ac0a060..e1f94bf1c 100644
--- a/megatron/model/biencoder_model.py
+++ b/megatron/model/biencoder_model.py
@@ -15,11 +15,30 @@
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
+def get_model_provider(only_query_model=False, only_context_model=False,
+ biencoder_shared_query_context_model=False):
+
+ def model_provider(pre_process=True, post_process=True):
+ """Build the model."""
+
+ print_rank_0('building Biencoder model ...')
+ model = biencoder_model_provider(only_query_model=only_query_model,
+ only_context_model = only_context_model,
+ biencoder_shared_query_context_model = \
+ biencoder_shared_query_context_model,
+ pre_process=pre_process, post_process=post_process)
+
+ return model
+
+ return model_provider
+
+
def biencoder_model_provider(only_query_model=False,
only_context_model=False,
- biencoder_shared_query_context_model=False):
+ biencoder_shared_query_context_model=False,
+ pre_process=True,
+ post_process=True):
"""Build the model."""
- args = get_args()
assert mpu.get_tensor_model_parallel_world_size() == 1 and \
mpu.get_pipeline_model_parallel_world_size() == 1, \
@@ -35,7 +54,9 @@ def biencoder_model_provider(only_query_model=False,
only_query_model=only_query_model,
only_context_model=only_context_model,
biencoder_shared_query_context_model=\
- biencoder_shared_query_context_model)
+ biencoder_shared_query_context_model,
+ pre_process=pre_process,
+ post_process=post_process)
return model
@@ -48,13 +69,17 @@ def __init__(self,
parallel_output=True,
only_query_model=False,
only_context_model=False,
- biencoder_shared_query_context_model=False):
+ biencoder_shared_query_context_model=False,
+ pre_process=True,
+ post_process=True):
super(BiEncoderModel, self).__init__()
args = get_args()
bert_kwargs = dict(
num_tokentypes=num_tokentypes,
- parallel_output=parallel_output)
+ parallel_output=parallel_output,
+ pre_process=pre_process,
+ post_process=post_process)
self.biencoder_shared_query_context_model = \
biencoder_shared_query_context_model
@@ -78,6 +103,13 @@ def __init__(self,
self.context_model = PretrainedBertModel(**bert_kwargs)
self._context_key = 'context_model'
+ def set_input_tensor(self, input_tensor):
+ """See megatron.model.transformer.set_input_tensor()"""
+ # this is just a placeholder; it will be needed when model
+ # parallelism is used
+ # self.language_model.set_input_tensor(input_tensor)
+ return
+
def forward(self, query_tokens, query_attention_mask, query_types,
context_tokens, context_attention_mask, context_types):
"""Run a forward pass for each of the models and
@@ -217,7 +249,7 @@ class PretrainedBertModel(MegatronModule):
learned information retrieval."""
def __init__(self, num_tokentypes=2,
- parallel_output=True):
+ parallel_output=True, pre_process=True, post_process=True):
super(PretrainedBertModel, self).__init__()
args = get_args()
@@ -225,6 +257,8 @@ def __init__(self, num_tokentypes=2,
self.pad_id = tokenizer.pad
self.biencoder_projection_dim = args.biencoder_projection_dim
self.parallel_output = parallel_output
+ self.pre_process = pre_process
+ self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers)
@@ -234,7 +268,9 @@ def __init__(self, num_tokentypes=2,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
- scaled_init_method=scaled_init_method)
+ scaled_init_method=scaled_init_method,
+ pre_process=self.pre_process,
+ post_process=self.post_process)
if args.biencoder_projection_dim > 0:
self.projection_enc = get_linear_layer(args.hidden_size,
@@ -247,7 +283,6 @@ def forward(self, input_ids, attention_mask, tokentype_ids=None):
#extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
-
lm_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
@@ -285,7 +320,7 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
- print_rank_0("loading BERT weights")
+ print_rank_0("loading pretrained weights")
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
diff --git a/megatron/model/vit_model.py b/megatron/model/vit_model.py
index 84a52a829..a1a86cfff 100644
--- a/megatron/model/vit_model.py
+++ b/megatron/model/vit_model.py
@@ -50,11 +50,11 @@ def __init__(self, hidden_size, num_classes):
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
- x = hidden_states[:, sequence_index, :]
- x = self.dense_in(x)
- x = torch.tanh(x)
- x = self.dense_out(x)
- return x
+ hidden_state = hidden_states[:, sequence_index, :]
+ dense_in_result = self.dense_in(hidden_state)
+ tanh_result = torch.tanh(dense_in_result)
+ dense_out_result = self.dense_out(tanh_result)
+ return dense_out_result
def twod_interpolate_position_embeddings_hook(
@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
class VitModel(MegatronModule):
"""Vision Transformer Model."""
- def __init__(self, num_classes, finetune=False):
- super(VitModel, self).__init__()
+ def __init__(self,
+ num_classes,
+ finetune=False,
+ pre_process=True,
+ post_process=True):
+ super(VitModel, self).__init__(share_word_embeddings=False)
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
@@ -136,6 +140,8 @@ def __init__(self, num_classes, finetune=False):
args.init_method_std, args.num_layers
)
+ self.pre_process = pre_process
+ self.post_process = post_process
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.patch_dim = args.patch_dim
@@ -148,63 +154,81 @@ def __init__(self, num_classes, finetune=False):
self.seq_length = self.num_patches + 1
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
- # cls_token
- self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
- torch.nn.init.zeros_(self.cls_token)
+ if self.pre_process:
+ # cls_token
+ self.cls_token = torch.nn.Parameter(
+ torch.randn(1, 1, self.hidden_size)
+ )
+ torch.nn.init.zeros_(self.cls_token)
- # Linear encoder
- self.linear_encoder = torch.nn.Linear(
- self.flatten_dim, self.hidden_size
- )
+ # Linear encoder
+ self.linear_encoder = torch.nn.Linear(
+ self.flatten_dim, self.hidden_size
+ )
- # embedding
- self.position_embeddings = torch.nn.Embedding(
- self.seq_length, self.hidden_size
- )
- init_method_normal(args.init_method_std)(
- self.position_embeddings.weight
- )
- self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
+ # embedding
+ self.position_embeddings = torch.nn.Embedding(
+ self.seq_length, self.hidden_size
+ )
+ init_method_normal(args.init_method_std)(
+ self.position_embeddings.weight
+ )
+ self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
- self.position_embeddings._register_load_state_dict_pre_hook(
- twod_interpolate_position_embeddings_hook
- )
+ self.position_embeddings._register_load_state_dict_pre_hook(
+ twod_interpolate_position_embeddings_hook
+ )
- self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
+ self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
# Transformer
self.transformer = ParallelTransformer(
- self.init_method, self.scaled_init_method
+ self.init_method,
+ self.scaled_init_method,
+ pre_process=self.pre_process,
+ post_process=self.post_process
)
- # MLP head
- if not self.finetune:
- self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
- else:
- self.class_head = get_linear_layer(
- self.hidden_size, num_classes, torch.nn.init.zeros_
+ if self.post_process:
+ # MLP head
+ if not self.finetune:
+ self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
+ else:
+ self.class_head = get_linear_layer(
+ self.hidden_size, num_classes, torch.nn.init.zeros_
+ )
+
+ def set_input_tensor(self, input_tensor):
+ """See megatron.model.transformer.set_input_tensor()"""
+ self.transformer.set_input_tensor(input_tensor)
+
+ def forward(self, input):
+
+ if self.pre_process:
+ rearranged_input = einops.rearrange(
+ input,
+ "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+ p1=self.patch_dim,
+ p2=self.patch_dim,
)
- def forward(self, x):
- x = einops.rearrange(
- x,
- "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
- p1=self.patch_dim,
- p2=self.patch_dim,
- )
+ assert rearranged_input.dtype == torch.half
+ encoder_output = self.linear_encoder(rearranged_input)
+ cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
+ concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
- assert x.dtype == torch.half
- x = self.linear_encoder(x)
- cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
- x = torch.cat((cls_tokens, x), dim=1)
+ token_embeddings = concatenated_tokens + \
+ self.position_embeddings(self.position_ids)
+ hidden_states = self.embedding_dropout(token_embeddings)
+ else:
+ hidden_states = input
- x = x + self.position_embeddings(self.position_ids)
- x = self.embedding_dropout(x)
- x = self.transformer(x, None)
+ hidden_states = self.transformer(hidden_states, None)
- if not self.finetune:
- x = self.mlp_head(x)
- else:
- x = self.class_head(x[:, 0, :])
+ if self.post_process:
+ if not self.finetune:
+ hidden_states = self.mlp_head(hidden_states)
+ else:
+ hidden_states = self.class_head(hidden_states[:, 0, :])
- return x
+ return hidden_states
diff --git a/megatron/tokenizer/bert_tokenization.py b/megatron/tokenizer/bert_tokenization.py
index a3aa6d907..99f9a8795 100644
--- a/megatron/tokenizer/bert_tokenization.py
+++ b/megatron/tokenizer/bert_tokenization.py
@@ -181,6 +181,35 @@ def convert_tokens_to_ids(self, tokens):
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
+ @staticmethod
+ def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
+ """ Converts a sequence of tokens (string) in a single string. """
+
+ def clean_up_tokenization(out_string):
+ """ Clean up a list of simple English tokenization artifacts
+ like spaces before punctuation and abbreviated forms.
+ """
+ out_string = (
+ out_string.replace(" .", ".")
+ .replace(" ?", "?")
+ .replace(" !", "!")
+ .replace(" ,", ",")
+ .replace(" ' ", "'")
+ .replace(" n't", "n't")
+ .replace(" 'm", "'m")
+ .replace(" 's", "'s")
+ .replace(" 've", "'ve")
+ .replace(" 're", "'re")
+ )
+ return out_string
+
+ text = ' '.join(tokens).replace(' ##', '').strip()
+ if clean_up_tokenization_spaces:
+ clean_text = clean_up_tokenization(text)
+ return clean_text
+ else:
+ return text
+
def vocab_size(self):
return len(self.vocab)
diff --git a/megatron/training.py b/megatron/training.py
index 5adef3c11..0cdb0d1bc 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -110,6 +110,17 @@ def pretrain(train_valid_test_dataset_provider,
args = get_args()
timers = get_timers()
+ # XXX: quick hack-in for now - add a clean wrapper later
+ if args.codecarbon_dir is not None:
+ import codecarbon
+ from pathlib import Path
+ print("CC START")
+
+ Path(args.codecarbon_dir).mkdir(parents=True, exist_ok=True)
+ output_file = f"emissions-{args.rank:03d}.csv"
+ cc_tracker = codecarbon.EmissionsTracker(output_dir=args.codecarbon_dir, output_file=output_file)
+ cc_tracker.start()
+
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -162,6 +173,12 @@ def pretrain(train_valid_test_dataset_provider,
test_data_iterator, model,
0, True)
+
+ # XXX: clean up
+ if args.codecarbon_dir is not None:
+ print("CC STOP")
+ cc_tracker.stop()
+
def update_train_iters(args):
# For iteration-based training, we don't need to do anything
@@ -838,7 +855,10 @@ def build_train_valid_test_data_iterators(
assert args.train_samples is None, \
'only backward compatiblity support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
- if args.iteration > 0 and args.consumed_valid_samples == 0:
+ # it's possible that training ran but evaluation did not, in which case
+ # args.consumed_valid_samples == 0 is valid
+ # TODO: eval_interval could have changed between runs, so this might still be wrong
+ if args.iteration // args.eval_interval > 0:
assert args.train_samples is None, \
'only backward compatiblity support for iteration-based training'
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
diff --git a/pretrain_ict.py b/pretrain_ict.py
index 1438b3d57..79759250f 100644
--- a/pretrain_ict.py
+++ b/pretrain_ict.py
@@ -14,6 +14,8 @@
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
+
+from functools import partial
import math
import torch
@@ -31,13 +33,16 @@
from megatron.utils import average_losses_across_data_parallel_group
-def pretrain_ict_model_provider():
+def pretrain_ict_model_provider(pre_process=True, post_process=True):
args = get_args()
+
model = biencoder_model_provider(
only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
- args.biencoder_shared_query_context_model)
+ args.biencoder_shared_query_context_model,
+ pre_process=pre_process, post_process=post_process)
+
return model
def get_group_world_size_rank():
@@ -77,25 +82,9 @@ def backward(ctx, grad_output):
output = output_list[rank].contiguous()
return output
-def forward_step(data_iterator, model, input_tensor):
- """Forward step."""
+def loss_func(output_tensor):
args = get_args()
- timers = get_timers()
-
- # Get the batch.
- timers('batch-generator').start()
- query_tokens, query_mask, \
- context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
- timers('batch-generator').stop()
-
- # Query and Context Types
- query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
- context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
-
- # Forward model.
- query_logits, context_logits = model(query_tokens, query_mask,
- query_types, context_tokens,
- context_mask, context_types)
+ query_logits, context_logits = output_tensor
micro_batch_size = query_logits.shape[0]
# recall we assert that tensor_model_parallel_size == 1
@@ -137,6 +126,28 @@ def topk_accuracy(k):
return loss, stats_dict
+
+def forward_step(data_iterator, model):
+ """Forward step."""
+ args = get_args()
+ timers = get_timers()
+
+ # Get the batch.
+ timers('batch-generator').start()
+ query_tokens, query_mask, \
+ context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
+ timers('batch-generator').stop()
+
+ # Query and Context Types
+ query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+ context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
+
+ # Forward model.
+ output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
+ context_mask, context_types)
+
+ return output_tensor, partial(loss_func)
+
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
diff --git a/pretrain_vit.py b/pretrain_vit.py
index 16ec10439..7770c68d5 100644
--- a/pretrain_vit.py
+++ b/pretrain_vit.py
@@ -17,19 +17,22 @@
import torch
import torch.nn.functional as F
+from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vit_model import VitModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
-def model_provider():
+def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building VIT model ...")
args = get_args()
- model = VitModel(num_classes=args.num_classes)
+ model = VitModel(num_classes=args.num_classes,
+ pre_process=pre_process,
+ post_process=post_process)
return model
def get_batch(data_iterator):
@@ -42,10 +45,21 @@ def get_batch(data_iterator):
return images, labels
-def forward_step(data_iterator, model, input_tensor):
+def loss_func(labels, output_tensor):
+ logits = output_tensor.contiguous().float()
+ loss = F.cross_entropy(logits, labels)
+
+ outputs = torch.argmax(logits, -1)
+ correct = (outputs == labels).float()
+ accuracy = torch.mean(correct)
+
+ averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
+
+ return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+
+def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
- assert input_tensor is None
# Get the batch.
timers("batch-generator").start()
@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
timers("batch-generator").stop()
# Forward model. lm_labels
- logits = model(images).contiguous().float()
- loss = F.cross_entropy(logits, labels)
-
- outputs = torch.argmax(logits, -1)
- correct = (outputs == labels).float()
- accuracy = torch.mean(correct)
-
- averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
-
- return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+ output_tensor = model(images)
+ return output_tensor, partial(loss_func, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
diff --git a/requirements.txt b/requirements.txt
index 1f7389c3e..cd79071d0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,11 @@
+# megatron requirements
+
pybind11
torch
six
regex
numpy
+
+# big-science requirements
+
+codecarbon==1.2.0
diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index 918417b41..9411b1849 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -16,10 +16,10 @@
"""Finetune utilities."""
from functools import partial
-
+import sys
import torch
-from megatron import get_args
+from megatron import get_args, get_num_microbatches
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
@@ -80,7 +80,8 @@ def _cross_entropy_forward_step(batch, model):
return output_tensor, partial(cross_entropy_loss_func, labels)
-def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
+def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
+ task_collate_fn=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
@@ -96,7 +97,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
shuffle=False,
num_workers=num_workers,
drop_last=drop_last,
- pin_memory=True)
+ pin_memory=True,
+ collate_fn=task_collate_fn)
return data_loader
@@ -112,21 +114,24 @@ def _build_infinite_size_dataloader(dataloader):
iterator = dataloader.__iter__()
-def _build_train_valid_dataloaders(train_dataset, valid_dataset):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset,
+ task_collate_fn=None):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
- args.num_workers, not args.keep_last)
+ args.num_workers, not args.keep_last,
+ task_collate_fn)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
- args.num_workers, not args.keep_last)
+ args.num_workers, not args.keep_last,
+ task_collate_fn)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments
@@ -154,6 +159,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
args = get_args()
timers = get_timers()
+ assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
+
# Turn on training mode which enables dropout.
for m in model:
m.train()
@@ -188,6 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
+
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1
@@ -209,9 +217,11 @@ def _train(model, optimizer, lr_scheduler, forward_step,
optimizer, lr_scheduler)
# Checkpointing
+ saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
+ saved_checkpoint = True
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
@@ -220,6 +230,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
valid_dataloader, model,
iteration, False)
+ # Exiting based on iterations
+ if args.exit_interval and iteration % args.exit_interval == 0:
+ if not saved_checkpoint:
+ save_checkpoint(iteration, model, optimizer, lr_scheduler)
+ torch.distributed.barrier()
+ print_rank_0('exiting program at iteration {}'.format(iteration))
+ sys.exit()
+
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
@@ -231,7 +249,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
def finetune(train_valid_datasets_provider, model_provider,
forward_step=_cross_entropy_forward_step,
- end_of_epoch_callback_provider=None):
+ end_of_epoch_callback_provider=None,
+ task_collate_fn=None):
"""Main finetune function used across all tasks."""
args = get_args()
timers = get_timers()
@@ -244,7 +263,7 @@ def finetune(train_valid_datasets_provider, model_provider,
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
- train_dataset, valid_dataset)
+ train_dataset, valid_dataset, task_collate_fn)
else:
args.train_iters = 0
timers('train/valid/test dataset/dataloder').stop()
@@ -268,8 +287,11 @@ def finetune(train_valid_datasets_provider, model_provider,
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
+ original_rng = args.no_load_rng
+ args.no_load_rng = True
_ = load_checkpoint(model, None, None)
args.load = original_load
+ args.no_load_rng = original_rng
# This is critical when only model is loaded. We should make sure
# main parameters are also updated.
optimizer.reload_model_params()
diff --git a/tasks/main.py b/tasks/main.py
index f5bd5ad69..6d8fc8f5f 100644
--- a/tasks/main.py
+++ b/tasks/main.py
@@ -62,6 +62,29 @@ def get_tasks_args(parser):
group.add_argument('--faiss-topk-retrievals', type=int, default=100,
help='Number of blocks to use as top-k during retrieval')
+ # finetune for retriever
+ group.add_argument('--eval-micro-batch-size', type=int, default=None,
+ help='Eval Batch size per model instance (local batch '
+ 'size). Global batch size is local batch size '
+ 'times data parallel size.')
+ group.add_argument('--train-with-neg', action='store_true',
+ help='Whether to use negative examples during model '
+ 'training')
+ group.add_argument('--train-hard-neg', type=int, default=0,
+ help='Number of hard negative examples to use during '
+ 'training')
+
+
+ # parameters for Av.rank validation method
+ # Following options/arguments have been taken directly from DPR codebase
+ group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
+ help='Av.rank validation: how many hard negatives to'
+ ' take from each question pool')
+ group.add_argument('--val-av-rank-other-neg', type=int, default=30,
+ help='Av.rank validation: how many other negatives to'
+ ' take from each question pool')
+
+
return parser
@@ -81,8 +104,10 @@ def get_tasks_args(parser):
from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt.evaluate import main
- elif args.task in ['ICT-ZEROSHOT-NQ']:
+ elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
from orqa.evaluate_orqa import main
+ elif args.task in ['RET-FINETUNE-NQ']:
+ from orqa.supervised.finetune import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
new file mode 100644
index 000000000..a8e8f8e6f
--- /dev/null
+++ b/tasks/orqa/README.md
@@ -0,0 +1,36 @@
+## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
+
+Below we present the steps to run unsupervised and supervised training and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
+
+## Retriever Training
+
+#### Unsupervised pretraining
+1. Use [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to preprocess the dataset for the Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes a corpus in loose JSON format as input and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this means multiple sentences per block and multiple blocks per document. Run it with the `--split-sentences` argument so that sentences are the basic unit. We construct two indexed datasets, one with the title of every document and another with the body:
+
+
+python tools/preprocess_data.py \
+ --input /path/to/corpus.json \
+ --json-keys text title \
+ --split-sentences \
+ --tokenizer-type BertWordPieceLowerCase \
+ --vocab-file /path/to/vocab.txt \
+ --output-prefix corpus_indexed \
+ --workers 10
+
+
+2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs single-GPU training of a 217M-parameter biencoder model for ICT retriever training. Single-GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script starts from a pretrained BERT model, and we use a total batch size of 4096 for the ICT training.
+
+3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).
+
+#### Supervised finetuning
+
+1. Use the above pretrained ICT model as the starting point for finetuning on [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example of how to perform the training. Our finetuning process adds retriever score scaling and longer training (80 epochs) on top of [DPR training](https://arxiv.org/abs/2004.04906).
+
+2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model.
+
+More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).
+
+## Reader Training
+
+The reader component will be available soon.
+
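Condensing the steps from the README above into one outline, using the example scripts referenced there (each script still needs its placeholders, such as EVIDENCE_DATA_DIR, EMBEDDING_PATH, CHECKPOINT_PATH and QA_FILE, filled in first; treat this as a sketch of the flow rather than exact commands):

# 1. unsupervised ICT pretraining of the biencoder
bash examples/pretrain_ict.sh

# 2. evaluate the pretrained retriever on Natural Questions Open
bash examples/evaluate_retriever_nq.sh

# 3. supervised finetuning on the DPR-prepared NQ data
bash examples/finetune_retriever_distributed.sh

# 4. re-run the same evaluation script against the finetuned checkpoint
bash examples/evaluate_retriever_nq.sh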
diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py
index 7e6b26923..87c59ea30 100644
--- a/tasks/orqa/evaluate_orqa.py
+++ b/tasks/orqa/evaluate_orqa.py
@@ -15,10 +15,8 @@
"""Main tasks functionality."""
-import os
-import sys
-
-from megatron import get_args
+from megatron import get_args, print_rank_0
+from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
def main():
@@ -28,6 +26,20 @@ def main():
args = get_args()
+ """
+ Create a BlockData data structure by running an IndexBuilder over an
+ ICT dataset, and then evaluate on the NQ task
+ """
+
+ print_rank_0("Starting index builder!")
+
+ index_builder = IndexBuilder()
+ index_builder.build_and_save_index()
+ print_rank_0("Build and save indices: done!")
+
+
+ print_rank_0("Starting evaluations!")
+
# Set up the model and evaluator
evaluator = ORQAEvaluator()
diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py
index ebee03522..08b1e929b 100644
--- a/tasks/orqa/evaluate_utils.py
+++ b/tasks/orqa/evaluate_utils.py
@@ -18,13 +18,14 @@
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
-from tasks.orqa.natural_questions.nq import get_nq_dataset
-from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
-from tasks.orqa.natural_questions.nq import process_nq_batch
-from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
-from megatron.model.biencoder_model import biencoder_model_provider
+from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
+from tasks.orqa.unsupervised.nq import get_nq_dataset
+from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
+from tasks.orqa.unsupervised.nq import process_nq_batch
+from tasks.orqa.unsupervised.qa_utils import calculate_matches
+
class ORQAEvaluator(object):
def __init__(self):
@@ -44,9 +45,8 @@ def __init__(self):
if args.biencoder_shared_query_context_model:
only_query_model = False
- model = get_model(lambda: biencoder_model_provider(only_query_model=\
- only_query_model, biencoder_shared_query_context_model=\
- args.biencoder_shared_query_context_model))
+ model = get_model(get_model_provider(only_query_model=only_query_model,
+ biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model)
diff --git a/tasks/orqa/supervised/data.py b/tasks/orqa/supervised/data.py
new file mode 100644
index 000000000..b45a842b6
--- /dev/null
+++ b/tasks/orqa/supervised/data.py
@@ -0,0 +1,300 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ORQA dataset."""
+
+import json
+import random
+from abc import ABC
+from abc import abstractmethod
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from megatron import print_rank_0, get_args
+from megatron.data.biencoder_dataset_utils import make_attention_mask
+
+def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
+ ctx_id_list, ctx_types_list = [], []
+ for context in ctx_list:
+ title_ids = tokenizer.tokenize(context['title'])
+ ctx_ids = tokenizer.tokenize(context['text'])
+ ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids
+
+ ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids,
+ max_seq_length, tokenizer.cls,
+ tokenizer.sep, tokenizer.pad)
+ ctx_id_list.append(ctx_ids)
+ ctx_types_list.append(ctx_types)
+
+ return ctx_id_list, ctx_types_list
+
+
+def build_tokens_types_paddings_from_text(query, context,
+ tokenizer, max_seq_length):
+ """Build token types and paddings, trim if needed, and pad if needed."""
+
+ query_ids = tokenizer.tokenize(query)
+ query_ids, query_types, query_pad_mask = \
+ build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \
+ tokenizer.cls, tokenizer.sep, tokenizer.pad)
+
+ # Append the title of the context at the front
+ extended_ctx_ids = None
+ if context is not None:
+ title_ids = tokenizer.tokenize(context['title'])
+ ctx_ids = tokenizer.tokenize(context['text'])
+ extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids
+
+ ctx_ids, ctx_types, ctx_pad_mask = \
+ build_tokens_types_paddings_from_ids(extended_ctx_ids,
+ max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
+
+ return query_ids, query_types, query_pad_mask, \
+ ctx_ids, ctx_types, ctx_pad_mask
+
+
+# Similar to the code in tasks/data_utils.py, with some changes
+def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
+ cls_id, sep_id, pad_id):
+ """Build token types and paddings, trim if needed, and pad if needed."""
+ enc_ids = []
+ tokentypes_enc = []
+
+ # [CLS].
+ enc_ids.append(cls_id)
+ tokentypes_enc.append(0)
+
+ # A.
+ len_src = len(text_ids)
+ enc_ids.extend(text_ids)
+ tokentypes_enc.extend([0] * len_src)
+
+ # Cap the size.
+ if len(enc_ids) > max_seq_length - 1:
+ enc_ids = enc_ids[0: max_seq_length - 1]
+ tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
+
+ # [SEP].
+ enc_ids.append(sep_id)
+ tokentypes_enc.append(0)
+
+ num_tokens_enc = len(enc_ids)
+ # Padding.
+ padding_length = max_seq_length - len(enc_ids)
+ if padding_length > 0:
+ enc_ids.extend([pad_id] * padding_length)
+ tokentypes_enc.extend([pad_id] * padding_length)
+
+ pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
+ pad_mask = np.array(pad_mask, dtype=np.int64)
+
+ return enc_ids, tokentypes_enc, pad_mask
+
+
+def build_sample(query_ids, query_types, query_pad_mask,
+ ctx_ids, ctx_types, ctx_pad_mask, answers,
+ neg_ctx_id_list=None, neg_ctx_types_list=None,
+ include_neg=False):
+ """Convert to numpy and return a sample consumed by the batch producer."""
+
+ query_ids = np.array(query_ids, dtype=np.int64)
+ query_types = np.array(query_types, dtype=np.int64)
+ query_mask = make_attention_mask(query_ids, query_ids)
+
+ ctx_ids = np.array(ctx_ids, dtype=np.int64)
+ ctx_types = np.array(ctx_types, dtype=np.int64)
+ ctx_mask = make_attention_mask(ctx_ids, ctx_ids)
+
+ sample = ({
+ 'query': query_ids,
+ 'query_mask': query_mask,
+ 'query_types': query_types,
+ 'query_pad_mask': query_pad_mask,
+ 'context': ctx_ids,
+ 'context_mask': ctx_mask,
+ 'context_types': ctx_types,
+ 'context_pad_mask': ctx_pad_mask,
+ 'reference': answers
+ })
+
+ if include_neg:
+ neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
+ neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
+ neg_ctx_mask = np.array([make_attention_mask(ids, ids) \
+ for ids in neg_ctx_ids], dtype=np.int64)
+
+ sample['neg_context'] = neg_ctx_ids
+ sample['neg_context_types'] = neg_ctx_id_types
+ sample['neg_context_mask'] = neg_ctx_mask
+
+ return sample
+
+
+class OpenRetrievalAbstractDataset(ABC, Dataset):
+ """Open Retrieval base dataset class."""
+
+ def __init__(self, task_name, dataset_name, datapaths, tokenizer, \
+ max_seq_length, evaluate=False):
+ # Store inputs.
+ args = get_args()
+ self.evaluate = evaluate
+ self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
+ self.val_av_rank_other_neg = args.val_av_rank_other_neg
+ self.train_with_neg = args.train_with_neg
+ self.train_hard_neg = args.train_hard_neg
+
+ self.task_name = task_name
+ self.dataset_name = dataset_name
+ self.tokenizer = tokenizer
+ self.max_seq_length = max_seq_length
+ print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
+ self.dataset_name))
+ # Process the files.
+ string = ' > paths:'
+ for path in datapaths:
+ string += ' ' + path
+ print_rank_0(string)
+ self.samples = []
+ for datapath in datapaths:
+ self.samples.extend(self.process_samples_from_single_path(datapath))
+
+ args = get_args()
+ if args.sample_rate < 1: # subsample
+ k = int(len(self.samples) * args.sample_rate)
+ self.samples = random.sample(self.samples, k)
+
+ print_rank_0(' >> total number of samples: {}'.format(
+ len(self.samples)))
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, idx):
+ raw_sample = self.samples[idx]
+
+ query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
+ ctx_pad_mask = build_tokens_types_paddings_from_text( \
+ raw_sample['question'], raw_sample['pos_context'], \
+ self.tokenizer, self.max_seq_length)
+
+ if self.evaluate:
+ neg_ctx_list = \
+ raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
+ raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
+ neg_ctx_id_list, neg_ctx_types_list = \
+ build_token_types_from_context_list(neg_ctx_list, \
+ self.tokenizer, self.max_seq_length)
+
+ elif self.train_with_neg:
+ hard_negative_ctx = raw_sample['hard_negative_context']
+ negative_ctx = raw_sample['negative_context']
+ if True: # TODO: fix this or remove this condition
+ random.shuffle(hard_negative_ctx)
+ random.shuffle(negative_ctx)
+
+ neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
+ # In the Google NQ dataset prepared by the DPR paper, around 50 or more
+ # training examples are missing hard negatives.
+ # In those cases, substitute hard negatives with simple negatives.
+ if len(neg_ctx_list) < self.train_hard_neg:
+ neg_ctx_list += negative_ctx[:self.train_hard_neg - \
+ len(neg_ctx_list)]
+
+ neg_ctx_id_list, neg_ctx_types_list = \
+ build_token_types_from_context_list(neg_ctx_list,
+ self.tokenizer, self.max_seq_length)
+ else:
+ neg_ctx_id_list = None
+ neg_ctx_types_list = None
+
+ sample = build_sample(query_ids, query_types, query_pad_mask,
+ ctx_ids, ctx_types, ctx_pad_mask,
+ raw_sample['answers'],
+ neg_ctx_id_list, neg_ctx_types_list,
+ include_neg=self.evaluate or self.train_with_neg)
+
+ return sample
+
+ @staticmethod
+ @abstractmethod
+ def process_samples_from_single_path(filename):
+ """Abstract method that takes a filename and
+ returns a list of dataset samples, each sample being a dict with keys
+ 'question', 'pos_context', 'negative_context', 'hard_negative_context' and 'answers'
+ """
+ pass
+
+
+
+def normalize_question(question):
+ if question[-1] == '?':
+ question = question[:-1]
+ return question
+
+# The following class reads the datasets for training retriever as
+# prepared by the DPR codebase (https://github.com/facebookresearch/DPR)
+
+class NQSupervisedDataset(OpenRetrievalAbstractDataset):
+
+ def __init__(self, name, datapaths, tokenizer, max_seq_length, \
+ evaluate=False):
+ super().__init__('natural_questions_ret',
+ name,
+ datapaths,
+ tokenizer,
+ max_seq_length,
+ evaluate=evaluate)
+
+ @staticmethod
+ def process_samples_from_single_path(filename):
+ """"Implement abstract method."""
+ print_rank_0(' > Processing {} ...'.format(filename))
+ samples = []
+ total = 0
+
+ with open(filename, 'r', encoding="utf-8") as f:
+ data = json.load(f)
+ for row in data:
+ question = normalize_question(row['question'])
+ pos_context = row['positive_ctxs'][0]
+
+ # Hard Negative Contexts
+ if len(row['hard_negative_ctxs']) > 0:
+ hard_neg_context = row['hard_negative_ctxs']
+ else:
+ hard_neg_context = []
+
+ # Negative Contexts
+ if len(row['negative_ctxs']) > 0:
+ neg_context = row['negative_ctxs']
+ else:
+ neg_context = []
+
+ answers = row['answers']
+ sample = {'question': question,
+ 'pos_context': pos_context,
+ 'hard_negative_context': hard_neg_context,
+ 'negative_context': neg_context,
+ 'answers': answers}
+ total += 1
+ samples.append(sample)
+
+ if total % 5000 == 0:
+ print_rank_0(' > processed {} so far ...'.format(total))
+
+ print_rank_0(' >> processed {} samples.'.format(len(samples)))
+ return samples
+
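NQSupervisedDataset above expects the DPR-style JSON produced by the download script referenced in the example scripts. A tiny hand-written sample, purely illustrative (the values are invented; only the key layout follows what process_samples_from_single_path and build_tokens_types_paddings_from_text read):

cat > nq-train.sample.json <<'EOF'
[
  {
    "question": "who wrote the declaration of independence?",
    "answers": ["Thomas Jefferson"],
    "positive_ctxs": [{"title": "Declaration of Independence",
                       "text": "The Declaration was drafted primarily by Thomas Jefferson."}],
    "negative_ctxs": [{"title": "Some unrelated article",
                       "text": "Text of a randomly sampled passage."}],
    "hard_negative_ctxs": [{"title": "Constitution of the United States",
                            "text": "A lexically similar but non-answering passage."}]
  }
]
EOF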
diff --git a/tasks/orqa/supervised/eval_utils.py b/tasks/orqa/supervised/eval_utils.py
new file mode 100644
index 000000000..67dca512b
--- /dev/null
+++ b/tasks/orqa/supervised/eval_utils.py
@@ -0,0 +1,206 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluation utilities."""
+from collections import OrderedDict
+import math
+import numpy as np
+import time
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from megatron import get_args, print_rank_0
+from megatron import mpu
+from megatron.utils import average_losses_across_data_parallel_group
+from tasks.finetune_utils import build_data_loader
+
+def task_collate_fn(batch_data):
+ # generate batch
+ batch_size = len(batch_data)
+ tensorized = OrderedDict()
+ for d in batch_data:
+ for k, v in d.items():
+ tensorized.setdefault(k, []).append(v)
+
+ tensorized['query'] = torch.LongTensor(tensorized['query'])
+ tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
+ tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
+ tensorized['query_pad_mask'] = \
+ torch.LongTensor(tensorized['query_pad_mask'])
+
+ tensorized['context'] = torch.LongTensor(tensorized['context'])
+ tensorized['context_mask'] = \
+ torch.LongTensor(tensorized['context_mask'])
+ tensorized['context_types'] = \
+ torch.LongTensor(tensorized['context_types'])
+ tensorized['context_pad_mask'] = \
+ torch.LongTensor(tensorized['context_pad_mask'])
+
+ if 'neg_context' in tensorized:
+ tensorized['neg_context'] = \
+ torch.LongTensor(np.concatenate(tensorized['neg_context']))
+ tensorized['neg_context_mask'] = \
+ torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
+ tensorized['neg_context_types'] = \
+ torch.LongTensor(np.concatenate(tensorized['neg_context_types']))
+
+ return tensorized
+
+
+
+def process_batch(batch):
+ """Process batch and produce inputs for the model."""
+ query_tokens = batch['query'].long().cuda()
+ query_mask = (batch['query_mask'] < 0.5).cuda()
+ query_types = batch['query_types'].long().cuda()
+ query_pad_mask = batch['query_pad_mask'].long().cuda()
+
+ context_tokens = batch['context'].long().cuda()
+ context_mask = (batch['context_mask'] < 0.5).cuda()
+ context_types = batch['context_types'].long().cuda()
+ context_pad_mask = batch['context_pad_mask'].long().cuda()
+
+ if 'neg_context' in batch:
+ neg_context_tokens = batch['neg_context'].long().cuda()
+ neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
+ neg_context_types = batch['neg_context_types'].long().cuda()
+ else:
+ neg_context_tokens = None
+ neg_context_mask = None
+ neg_context_types = None
+
+ reference = batch['reference']
+
+ return query_tokens, query_mask, query_types, query_pad_mask, \
+ context_tokens, context_mask, context_types, context_pad_mask, \
+ neg_context_tokens, neg_context_mask, neg_context_types, reference
+
+def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
+ """Provide function that calculates accuracies."""
+ args = get_args()
+
+ print_rank_0("accuracy_func_provider is CALLED")
+
+ # Build dataloaders
+ datapath = args.valid_data
+ dataset = single_dataset_provider(datapath)
+
+ drop_last = False
+ if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
+ drop_last = True
+
+ print_rank_0(datapath)
+ print_rank_0(rank0sampler)
+
+ dataloader = build_data_loader(dataset,
+ args.eval_micro_batch_size,
+ num_workers=args.num_workers,
+ drop_last=drop_last,
+ task_collate_fn=task_collate_fn)
+ dataloaders = (dataset.dataset_name, dataloader)
+
+ def metrics_func(model, epoch, output_predictions=False):
+ print_rank_0('calculating metrics by accuracy func in ORQA...')
+
+ if output_predictions:
+ assert rank0sampler
+ names = 'predictions'
+ name, dataloader = dataloaders
+ if args.task == "RET-FINETUNE-NQ":
+ start_time = time.time()
+ output = retrieval_loss(model, dataloader)
+ stats_dict, total = output
+ format_string = ""
+ for k, v in stats_dict.items():
+ format_string += "|{} = {:.2f}".format(k, v / total)
+ print_rank_0("epoch:{}{}".format(epoch, format_string))
+ print_rank_0("taken time to calcuate metrics {:.3f}".format(\
+ time.time() - start_time))
+ else:
+ raise AssertionError("{} Task not supported".format(args.task))
+
+ return metrics_func
+
+
+def retrieval_loss(model, dataloader):
+ args = get_args()
+ total = 0
+ topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
+ args.retriever_report_topk_accuracies}
+ stats_dict = dict(rank=0, **topk_stats_dict)
+
+ assert len(model) == 1
+ unwrapped_model = model[0]
+ unwrapped_model.eval()
+
+ with torch.no_grad():
+ # For all the batches in the dataset.
+ for batch in dataloader:
+ # Run the model forward.
+ query_tokens, query_mask, query_types, _, \
+ context_tokens, context_mask, context_types, _, \
+ neg_context_tokens, neg_context_mask, neg_context_types, \
+ reference = process_batch(batch)
+
+ query_logits, context_logits = unwrapped_model(query_tokens,
+ query_mask, query_types,
+ torch.cat([context_tokens, neg_context_tokens]),
+ torch.cat([context_mask, neg_context_mask]),
+ torch.cat([context_types, neg_context_types]))
+
+ retrieval_scores = torch.matmul(query_logits,
+ torch.transpose(context_logits, 0, 1))
+
+ if args.retriever_score_scaling:
+ retrieval_scores = retrieval_scores / \
+ math.sqrt(args.hidden_size)
+
+ local_batch_size = query_logits.shape[0]
+ labels = torch.arange(local_batch_size).long().cuda()
+
+ softmax_scores = F.softmax(retrieval_scores, dim=1)
+ sorted_vals, sorted_indices = torch.topk(softmax_scores,
+ k=softmax_scores.shape[1],
+ sorted=True)
+
+ def topk_accuracy(k):
+ return torch.cuda.FloatTensor(
+ [sum([int(labels[i] in sorted_indices[i, :k]) for i in \
+ range(local_batch_size)])])
+
+ def get_rank():
+ return torch.cuda.FloatTensor(
+ [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
+ for i in range(local_batch_size)])])
+
+ topk_accs = [topk_accuracy(k) for k in \
+ args.retriever_report_topk_accuracies]
+ rank = get_rank()
+ losses = average_losses_across_data_parallel_group([rank, \
+ *topk_accs])
+
+ # create stats_dict with retrieval loss and all specified
+ # top-k accuracies
+ topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
+ zip(args.retriever_report_topk_accuracies, losses[1:])}
+ temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
+ for k in stats_dict.keys():
+ stats_dict[k] += temp_stats_dict[k]
+ total += local_batch_size
+
+ unwrapped_model.train()
+
+ return stats_dict, total
diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py
new file mode 100644
index 000000000..aed65ac97
--- /dev/null
+++ b/tasks/orqa/supervised/finetune.py
@@ -0,0 +1,251 @@
+# coding=utf-8
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ORQA finetuning/evaluation."""
+
+from functools import partial
+import sys
+
+import math
+import torch
+import torch.nn.functional as F
+
+from megatron import get_args, get_timers, get_tokenizer
+from megatron import mpu, print_rank_0
+from megatron.indexer import IndexBuilder
+from megatron.model.biencoder_model import biencoder_model_provider
+from megatron.utils import average_losses_across_data_parallel_group
+from pretrain_ict import get_group_world_size_rank
+from tasks.finetune_utils import finetune
+from tasks.orqa.supervised.eval_utils import accuracy_func_provider
+from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
+from tasks.orqa.evaluate_utils import ORQAEvaluator
+
+# input_ is a 2D tensor
+def check_and_append_tensor_for_gather(group, rank, world_size, input_):
+
+ # gather the size of the first dimension of the tensor from all ranks
+ current_length = input_.size()[0]
+ first_dim = torch.tensor([[current_length]],
+ device=torch.cuda.current_device())
+ input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
+ input_list[rank].copy_(first_dim)
+ torch.distributed.all_gather(input_list, first_dim, group=group)
+ all_input_list = torch.cat(input_list, dim=0).contiguous()
+ max_length = torch.max(all_input_list)
+
+    # if this rank's tensor is shorter than the max, pad the end of its
+    # first dimension so every rank contributes an equally sized tensor
+ if max_length > current_length:
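+        # F.pad lists padding from the last dimension backwards, so a tuple
+        # of 2 * dim entries whose final value is the extra length pads only
+        # the end of the first dimension.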
+ padding=tuple([0] * (input_.dim() * 2 - 1)) + \
+ tuple([max_length - current_length])
+ input_ = F.pad(input=input_, pad=padding)
+
+ return input_
+
+def orqa(Dataset):
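+    """Finetune/evaluate a biencoder retriever on the given supervised Dataset."""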
+
+ def cross_entropy_forward_step(batch, model):
+ """Simple forward step with cross-entropy loss."""
+ timers = get_timers()
+ tokenizer = get_tokenizer()
+
+ # Get the batch.
+ timers('batch generator').start()
+ try:
+ batch_ = next(batch)
+ except BaseException:
+ batch_ = batch
+
+ group, rank, world_size = get_group_world_size_rank()
+
+ query_tokens, query_mask, query_types, query_pad_mask, \
+ context_tokens, context_mask, context_types, context_pad_mask, \
+ neg_context_tokens, neg_context_mask, neg_context_types, \
+ reference = process_batch(batch_)
+
+ timers('batch generator').stop()
+ local_batch_size = query_tokens.shape[0]
+
+ # Text representation of query and context
+ query_list, context_list = [], []
+ for i in range(local_batch_size):
+ query_list.append(tokenizer.decode(query_tokens[i].tolist()))
+ context_list.append(tokenizer.decode(context_tokens[i].tolist()))
+
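+        # The number of hard-negative contexts can differ across ranks, so
+        # pad them to a common length; otherwise the all_gather of context
+        # logits in the loss function would see mismatched shapes.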
+ if neg_context_tokens is not None:
+ neg_context_tokens = check_and_append_tensor_for_gather(group,
+ rank, world_size, neg_context_tokens)
+ neg_context_mask = check_and_append_tensor_for_gather(group,
+ rank, world_size, neg_context_mask)
+ neg_context_types = check_and_append_tensor_for_gather(group,
+ rank, world_size, neg_context_types)
+
+ if neg_context_tokens is not None:
+ context_tokens = torch.cat([context_tokens, neg_context_tokens])
+ context_mask = torch.cat([context_mask, neg_context_mask])
+ context_types = torch.cat([context_types, neg_context_types])
+
+ # Forward model.
+ output_tensor = model(query_tokens, query_mask,
+ query_types, context_tokens,
+ context_mask, context_types)
+ return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
+
+
+ def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
+ args = get_args()
+
+ local_batch_size = query_tokens.shape[0]
+ group, rank, world_size = get_group_world_size_rank()
+ # recall we assert that model_parallel_size == 1
+ global_batch_size = world_size * local_batch_size
+
+ query_logits, context_logits = output_tensor
+
+ if world_size > 1:
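+            # Gather context logits from every data-parallel rank so the
+            # softmax below spans the global batch. torch.distributed.all_gather
+            # does not propagate gradients, so the local rank's slot is
+            # replaced with the original tensor to keep it in the autograd
+            # graph.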
+ input_ = torch.empty_like(context_logits).copy_(\
+ context_logits).detach_()
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+ tensor_list[rank].copy_(input_)
+ torch.distributed.all_gather(tensor_list, input_, group=group)
+
+ # Check if all-gather happens in order
+ assert tensor_list[rank].sum().item() == \
+ context_logits.sum().item()
+
+ # Preserves the gradient
+ tensor_list[rank] = context_logits
+ all_context_logits = torch.cat(tensor_list, dim=0).contiguous()
+
+ # Query tensors
+ input_ = torch.empty_like(query_logits).copy_(\
+ query_logits).detach_()
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+ tensor_list[rank].copy_(input_)
+ torch.distributed.all_gather(tensor_list, input_, group=group)
+
+ # Check if all-gather happens in order
+ assert tensor_list[rank].sum().item() == query_logits.sum().item()
+
+ # Preserves the gradient
+ tensor_list[rank] = query_logits
+ all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
+ else:
+ all_query_logits = query_logits
+ all_context_logits = context_logits
+
+ retrieval_scores = torch.matmul(all_query_logits,
+ torch.transpose(all_context_logits, 0, 1))
+ # Scaling the retrieval scores
+ if args.retriever_score_scaling:
+ retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
+
+ if args.train_with_neg:
+ # if the world size is 3, local batch size is 4, and
+ # local context size is 8, what we want is
+ # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
+ labels = []
+ local_context_size = context_tokens.shape[0]
+ for i in range(world_size):
+ j = i * local_context_size
+ labels.extend(list(range(j, j + local_batch_size)))
+ labels = torch.LongTensor(labels).cuda()
+ assert len(labels) == global_batch_size
+ else:
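+            # Without extra negatives, the positive context for query i is
+            # simply row i of the gathered context logits.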
+ labels = torch.arange(global_batch_size).long().cuda()
+
+ # Cross-entropy loss.
+ softmax_scores = F.log_softmax(retrieval_scores, dim=1)
+
+ loss = F.nll_loss(softmax_scores, labels, reduction='mean')
+
+ max_score, max_idxs = torch.max(softmax_scores, 1)
+ correct_predictions_count = (max_idxs == labels).sum().float()
+
+ # Reduce loss for logging.
+ reduced_loss = average_losses_across_data_parallel_group([loss, \
+ correct_predictions_count])
+
+        # Scale the loss by the data-parallel world size so that gradient
+        # averaging across ranks does not shrink the loss, which is already
+        # computed over the gathered global batch.
+ loss = loss * mpu.get_data_parallel_world_size()
+
+ return loss, {'lm loss': reduced_loss[0],
+ 'correct_prediction_count': reduced_loss[1]}
+
+
+ def train_valid_datasets_provider():
+ """Build train and validation dataset."""
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ train_dataset = Dataset('training',
+ args.train_data,
+ tokenizer,
+ args.retriever_seq_length,
+ evaluate=False)
+ valid_dataset = Dataset('validation',
+ args.valid_data,
+ tokenizer,
+ args.retriever_seq_length,
+ evaluate=True)
+ return train_dataset, valid_dataset
+
+ def model_provider(pre_process=True, post_process=True):
+ """Build the model."""
+ args = get_args()
+ print_rank_0('building retriever model for {} ...'.format(args.task))
+
+ model = biencoder_model_provider(only_context_model=False,
+ only_query_model=False,
+ biencoder_shared_query_context_model=\
+ args.biencoder_shared_query_context_model,
+ pre_process=pre_process, post_process=post_process)
+
+ return model
+
+ def single_dataset_provider(datapath):
+ args = get_args()
+ tokenizer = get_tokenizer()
+
+ name = datapath[0].split('/')[-1].split('.')[0]
+ return Dataset(name,
+ datapath,
+ tokenizer,
+ args.retriever_seq_length,
+ evaluate=True)
+
+ def metrics_func_provider():
+ """Provide metrics callback function."""
+ return accuracy_func_provider(single_dataset_provider)
+
+ """Finetune/evaluate."""
+ finetune(train_valid_datasets_provider,
+ model_provider,
+ forward_step=cross_entropy_forward_step,
+ end_of_epoch_callback_provider=metrics_func_provider,
+ task_collate_fn=task_collate_fn)
+
+def main():
+ args = get_args()
+
+ if args.task == 'RET-FINETUNE-NQ':
+ from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
+ else:
+ raise NotImplementedError('ORQA task {} is not implemented.'.format(
+ args.task))
+
+ orqa(Dataset)
+
diff --git a/tasks/orqa/natural_questions/nq.py b/tasks/orqa/unsupervised/nq.py
similarity index 100%
rename from tasks/orqa/natural_questions/nq.py
rename to tasks/orqa/unsupervised/nq.py
diff --git a/tasks/orqa/natural_questions/qa_utils.py b/tasks/orqa/unsupervised/qa_utils.py
similarity index 98%
rename from tasks/orqa/natural_questions/qa_utils.py
rename to tasks/orqa/unsupervised/qa_utils.py
index 24e71e683..811a05834 100644
--- a/tasks/orqa/natural_questions/qa_utils.py
+++ b/tasks/orqa/unsupervised/qa_utils.py
@@ -22,7 +22,7 @@
from typing import Tuple, List, Dict
import regex as re
-from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer
+from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer
logger = logging.getLogger(__name__)
diff --git a/tasks/orqa/natural_questions/tokenizers.py b/tasks/orqa/unsupervised/tokenizers.py
similarity index 100%
rename from tasks/orqa/natural_questions/tokenizers.py
rename to tasks/orqa/unsupervised/tokenizers.py
diff --git a/tasks/vision/classification.py b/tasks/vision/classification.py
index 5232b3f54..71e840757 100644
--- a/tasks/vision/classification.py
+++ b/tasks/vision/classification.py
@@ -34,13 +34,14 @@ def train_valid_datasets_provider():
)
return train_ds, valid_ds
- def model_provider():
+ def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0("building classification model for ImageNet ...")
- return VitModel(num_classes=args.num_classes, finetune=True)
+ return VitModel(num_classes=args.num_classes, finetune=True,
+ pre_process=pre_process, post_process=post_process)
"""Finetune/evaluate."""
finetune(
diff --git a/tasks/vision/eval_utils.py b/tasks/vision/eval_utils.py
index aabc04a15..3a194119c 100644
--- a/tasks/vision/eval_utils.py
+++ b/tasks/vision/eval_utils.py
@@ -16,10 +16,14 @@
"""Evaluation utilities."""
import os
+from functools import partial
+
import torch
+
from megatron import get_args
-from megatron import print_rank_0
+from megatron import print_rank_0, print_rank_last
from megatron import mpu
+from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms
@@ -56,7 +60,7 @@ def metrics_func(model, epoch):
print_rank_0("calculating metrics ...")
correct, total = calculate_correct_answers(model, dataloader, epoch)
percent = float(correct) * 100.0 / float(total)
- print_rank_0(
+ print_rank_last(
" >> |epoch: {}| overall: correct / total = {} / {} = "
"{:.4f} %".format(epoch, correct, total, percent)
)
@@ -67,29 +71,61 @@ def metrics_func(model, epoch):
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
- model.eval()
+ args = get_args()
+ forward_backward_func = get_forward_backward_func()
+ for m in model:
+ m.eval()
+
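+    # The forward-backward schedule expects a loss function, but evaluation
+    # only needs the correct/total counts, so a dummy zero loss is returned.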
+ def loss_func(labels, output_tensor):
+ logits = output_tensor
+
+ loss_dict = {}
+ # Compute the correct answers.
+ predicted = torch.argmax(logits, dim=-1)
+ corrects = (predicted == labels).float()
+ # Add to the counters.
+ loss_dict['total'] = labels.size(0)
+ loss_dict['correct'] = corrects.sum().item()
+
+ return 0, loss_dict
+
+    # Defined inside so the forward step can close over the local loss_func.
+ def correct_answers_forward_step(batch, model):
+ try:
+ batch_ = next(batch)
+ except BaseException:
+ batch_ = batch
+ images, labels = process_batch(batch_)
+
+ # Forward model.
+ args = get_args()
+ output_tensor = model(images)
+
+ return output_tensor, partial(loss_func, labels)
+
with torch.no_grad():
# For all the batches in the dataset.
total = 0
correct = 0
for _, batch in enumerate(dataloader):
- # Run the model forward.
- images, labels = process_batch(batch)
- logits = model(images).contiguous().float()
- # Add output predictions.
- # Compute the correct answers.
- predicted = torch.argmax(logits, dim=-1)
- corrects = (predicted == labels).float()
- # Add to the counters.
- total += labels.size(0)
- correct += corrects.sum().item()
- model.train()
+
+ loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
+ optimizer=None, timers=None, forward_only=True)
+
+ for loss_dict in loss_dicts:
+ total += loss_dict['total']
+ correct += loss_dict['correct']
+
+ for m in model:
+ m.train()
# Reduce.
- unreduced = torch.cuda.LongTensor([correct, total])
- torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group())
+ if mpu.is_pipeline_last_stage():
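+        # Only the last pipeline stage produces logits and therefore holds
+        # the correct/total counts, so the reduction across data-parallel
+        # ranks happens there.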
+ unreduced = torch.cuda.LongTensor([correct, total])
+ torch.distributed.all_reduce(unreduced,
+ group=mpu.get_data_parallel_group())
- # Print on screen.
- correct_ans = unreduced[0].item()
- total_count = unreduced[1].item()
- return correct_ans, total_count
+ # Print on screen.
+ correct_ans = unreduced[0].item()
+ total_count = unreduced[1].item()
+ return correct_ans, total_count
diff --git a/tasks/vision/finetune_utils.py b/tasks/vision/finetune_utils.py
index afde4aa89..f9743883c 100644
--- a/tasks/vision/finetune_utils.py
+++ b/tasks/vision/finetune_utils.py
@@ -17,6 +17,7 @@
import torch
import torch.nn.functional as F
+from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
@@ -38,10 +39,21 @@ def process_batch(batch):
return images, labels
-def _cross_entropy_forward_step(batch, model, input_tensor):
+def cross_entropy_loss_func(labels, output_tensor):
+ logits = output_tensor
+
+ # Cross-entropy loss.
+ loss = F.cross_entropy(logits.contiguous().float(), labels)
+
+ # Reduce loss for logging.
+ averaged_loss = average_losses_across_data_parallel_group([loss])
+
+ return loss, {'lm loss': averaged_loss[0]}
+
+
+def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
- assert input_tensor is None
# Get the batch.
timers("batch generator").start()
@@ -52,16 +64,10 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
images, labels = process_batch(batch_)
timers("batch generator").stop()
- # Forward model.
- logits = model(images).contiguous().float()
-
- # Cross-entropy loss.
- loss = F.cross_entropy(logits, labels)
-
- # Reduce loss for logging.
- average_loss = average_losses_across_data_parallel_group([loss])
-
- return loss, {"lm loss": average_loss[0]}
+ # Forward model.
+ output_tensor = model(images)
+
+ return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
@@ -103,23 +109,28 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
"""Traing and validation dataloaders."""
args = get_args()
- print_rank_0("building train and validation dataloaders ...")
+ print_rank_0('building train and validation dataloaders ...')
# Training dataset.
- train_dataloader = build_data_loader(
- train_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
- )
+ train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
+ args.num_workers, not args.keep_last)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
- valid_dataloader_ = build_data_loader(
- valid_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
- )
+ valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
+ args.num_workers, not args.keep_last)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
- return train_dataloader, valid_dataloader
+ # Now that we've built the data loaders, set batch_size arguments
+ # to the actual batch size the model will see for this dataset.
+ # This is necessary so pipeline transfers know what size they are
+ # and the LR schedule, which is based on samples seen, gets set
+ # correctly.
+ args.orig_micro_batch_size = args.micro_batch_size
+ args.orig_global_batch_size = args.global_batch_size
+ return train_dataloader, valid_dataloader
def _train(
model,
@@ -135,7 +146,8 @@ def _train(
timers = get_timers()
# Turn on training mode which enables dropout.
- model.train()
+ for m in model:
+ m.train()
# Tracking loss.
losses_dict_sum = {}
@@ -166,12 +178,16 @@ def _train(
start_iteration = 0
# Train for one step.
- losses_dict, skipped_iter = train_step(
+ losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
forward_step, batch, model, optimizer, lr_scheduler
)
iteration += 1
# Logging.
+ params_norm = None
+ if args.log_params_norm:
+ params_norm = calc_params_l2_norm(model)
+
report_memory_flag = training_log(
losses_dict,
losses_dict_sum,
@@ -180,6 +196,9 @@ def _train(
optimizer.get_loss_scale().item(),
report_memory_flag,
skipped_iter,
+ grad_norm,
+ params_norm,
+ num_zeros_in_grad
)
# Autoresume
diff --git a/tools/create_doc_index.py b/tools/create_doc_index.py
deleted file mode 100644
index 4448d0e29..000000000
--- a/tools/create_doc_index.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
- os.path.pardir)))
-
-from megatron import print_rank_0
-from megatron.indexer import IndexBuilder
-from megatron.initialize import initialize_megatron
-
-
-def main():
- """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
- - Include all args needed for initial model specification
-
- Other key args:
- --block-data-path: path to write to
- --ict-load or --realm-load: path to checkpoint with which to embed
- --data-path and --titles-data-path: paths for dataset
- --indexer-log-interval: reporting interval
- --indexer-batch-size: size specific for indexer jobs
-
- Check README.md for example script
- """
-
- initialize_megatron(extra_args_provider=None,
- args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
- index_builder = IndexBuilder()
- index_builder.build_and_save_index()
- print_rank_0("Build and save indices: done!")
-
-if __name__ == "__main__":
- main()
-