diff --git a/README.md b/README.md
index 127091f83..4b3e08ce2 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,11 @@ python tools/preprocess_data.py \
 The output will be two files named, in this case, `my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`. The `--data-path` specified in later BERT training is the full path and new filename, but without the file extension.
 
+For T5 use the same preprocessing as BERT, perhaps renaming it to:
+<pre>
+       --output-prefix my-t5 \
+</pre>
+
 Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type:
 python tools/preprocess_data.py \
@@ -247,13 +252,14 @@ T5_ARGS="--num-layers 24 \
          --micro-batch-size 16 \
          --global-batch-size 2048 \
          --vocab-file $VOCAB_FILE \
+         --vocab-extra-ids 100 \
          --split 949,50,1 \
          --fp16"
 
 OUTPUT_ARGS=<same as those in BERT pretraining above>
 
 python pretrain_t5.py \
-       $BERT_ARGS \
+       $T5_ARGS \
        $OUTPUT_ARGS \
        --save $CHECKPOINT_PATH \
        --load $CHECKPOINT_PATH \
diff --git a/examples/create_embeddings.sh b/examples/create_embeddings.sh
deleted file mode 100644
index 59a5839f7..000000000
--- a/examples/create_embeddings.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/bin/bash
-
-# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)
-
-RANK=0
-WORLD_SIZE=1
-
-# Wikipedia data can be downloaded from the following link:
-# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
-EVIDENCE_DATA_DIR=
-EMBEDDING_PATH=
-CHECKPOINT_PATH=
-
-python tools/create_doc_index.py \
-    --num-layers 12 \
-    --hidden-size 768 \
-    --num-attention-heads 12 \
-    --tensor-model-parallel-size 1 \
-    --micro-batch-size 128 \
-    --checkpoint-activations \
-    --seq-length 512 \
-    --retriever-seq-length 256 \
-    --max-position-embeddings 512 \
-    --load ${CHECKPOINT_PATH} \
-    --evidence-data-path ${EVIDENCE_DATA_DIR} \
-    --embedding-path ${EMBEDDING_PATH} \
-    --indexer-log-interval 1000 \
-    --indexer-batch-size 128 \
-    --vocab-file bert-vocab.txt \
-    --num-workers 2 \
-    --fp16
-
diff --git a/examples/evaluate_ict_zeroshot_nq.sh b/examples/evaluate_retriever_nq.sh
similarity index 75%
rename from examples/evaluate_ict_zeroshot_nq.sh
rename to examples/evaluate_retriever_nq.sh
index e1ce45a93..8b87be302 100644
--- a/examples/evaluate_ict_zeroshot_nq.sh
+++ b/examples/evaluate_retriever_nq.sh
@@ -1,19 +1,19 @@
 #!/bin/bash
 
 # Evaluate natural question test data given Wikipedia embeddings and pretrained
-# ICT model
+# ICT model or a finetuned model for Natural Question task
 
 # Datasets can be downloaded from the following link:
 # https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
 
 EVIDENCE_DATA_DIR=
 EMBEDDING_PATH=
-CHECKPOINT_PATH=
+CHECKPOINT_PATH=
 
-QA_FILE=
+QA_FILE=
 
 python tasks/main.py \
-    --task ICT-ZEROSHOT-NQ \
+    --task RETRIEVER-EVAL \
     --tokenizer-type BertWordPieceLowerCase \
     --num-layers 12 \
     --hidden-size 768 \
@@ -29,8 +29,10 @@ python tasks/main.py \
     --retriever-seq-length 256 \
     --vocab-file  bert-vocab.txt\
     --qa-data-test ${QA_FILE} \
-    --num-workers 2 \
     --faiss-use-gpu \
     --retriever-report-topk-accuracies 1 5 20 100 \
-    --fp16
+    --fp16 \
+    --indexer-log-interval 1000 \
+    --indexer-batch-size 128
+
 
diff --git a/examples/finetune_retriever_distributed.sh b/examples/finetune_retriever_distributed.sh
new file mode 100755
index 000000000..535a2e053
--- /dev/null
+++ b/examples/finetune_retriever_distributed.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# Finetune a BERT or pretrained ICT model using Google's Natural Questions data.
+# Datasets can be downloaded from the following link:
+# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
+
+WORLD_SIZE=8
+
+DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
+                  --nnodes 1 \
+                  --node_rank 0 \
+                  --master_addr localhost \
+                  --master_port 6000"
+
+CHECKPOINT_PATH=
+
+# Load either of the below
+BERT_LOAD_PATH=
+PRETRAINED_CHECKPOINT=
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
+        --task RET-FINETUNE-NQ \
+        --train-with-neg \
+        --train-hard-neg 1 \
+        --pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \
+        --num-layers 12 \
+        --hidden-size 768 \
+        --num-attention-heads 12 \
+        --tensor-model-parallel-size 1 \
+        --tokenizer-type BertWordPieceLowerCase \
+        --train-data nq-train.json \
+        --valid-data nq-dev.json \
+        --save ${CHECKPOINT_PATH} \
+        --load ${CHECKPOINT_PATH} \
+        --vocab-file bert-vocab.txt \
+        --bert-load ${BERT_LOAD_PATH} \
+        --save-interval 5000 \
+        --log-interval 10 \
+        --eval-interval 20000 \
+        --eval-iters 100 \
+        --indexer-log-interval 1000 \
+        --faiss-use-gpu \
+        --DDP-impl torch \
+        --fp16 \
+        --retriever-report-topk-accuracies 1 5 10 20 100 \
+        --seq-length 512 \
+        --retriever-seq-length 256 \
+        --max-position-embeddings 512 \
+        --retriever-score-scaling \
+        --epochs 80 \
+        --micro-batch-size 8 \
+        --eval-micro-batch-size 16 \
+        --indexer-batch-size 128 \
+        --lr 2e-5 \
+        --lr-warmup-fraction 0.01 \
+        --weight-decay 1e-1
diff --git a/examples/pretrain_t5.sh b/examples/pretrain_t5.sh
index 71fea8489..91fd5929b 100644
--- a/examples/pretrain_t5.sh
+++ b/examples/pretrain_t5.sh
@@ -15,7 +15,7 @@ python pretrain_t5.py \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 2048 \
+       --global-batch-size 16 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
        --lr-decay-iters 1000000 \
@@ -35,4 +35,5 @@ python pretrain_t5.py \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16
+       --fp16 \
+       --vocab-extra-ids 100
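The `--vocab-extra-ids 100` flag added to the T5 scripts reserves the sentinel tokens that T5's span-corruption objective needs. A rough, self-contained sketch of what reserving extra ids amounts to (an illustrative helper, not Megatron's actual tokenizer code):

```python
def add_extra_vocab_ids(vocab, num_extra_ids=100):
    """Append T5-style sentinel tokens (<extra_id_0> ... <extra_id_99>) to a vocab dict."""
    base = len(vocab)
    for i in range(num_extra_ids):
        vocab["<extra_id_{}>".format(i)] = base + i
    return vocab

# Toy vocabulary: 3 special tokens grow to 103 entries once the sentinels are added.
vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2}
vocab = add_extra_vocab_ids(vocab)
assert len(vocab) == 103
```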
diff --git a/examples/pretrain_t5_distributed.sh b/examples/pretrain_t5_distributed.sh
index 778b4ad2a..2beb1cdac 100644
--- a/examples/pretrain_t5_distributed.sh
+++ b/examples/pretrain_t5_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 2048 \
+       --global-batch-size 128 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
        --lr-decay-iters 1000000 \
@@ -44,4 +44,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16
+       --fp16 \
+       --vocab-extra-ids 100
diff --git a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh
index 9be70393d..23f1cd664 100644
--- a/examples/pretrain_t5_distributed_with_mp.sh
+++ b/examples/pretrain_t5_distributed_with_mp.sh
@@ -24,8 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 2048 \
-       --seq-length 512 \
+       --global-batch-size 128 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
        --lr-decay-iters 1000000 \
@@ -45,4 +44,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16
+       --fp16 \
+       --vocab-extra-ids 100
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 0b21ffbdb..49f8b253a 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -440,6 +440,8 @@ def _add_training_args(parser):
                        help='Run optimizer on CPU')
     group.add_argument('--cpu_torch_adam', action='store_true',
                        help='Use Torch Adam as optimizer on CPU.')
+    group.add_argument('--codecarbon-dir', type=str, default=None,
+                       help='Write CodeCarbon logs to this directory.')
 
     return parser
 
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 3cc6a8e2e..49c4b6533 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -60,8 +60,8 @@ def _compare(arg_name, old_arg_name=None):
     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
-    _compare('max_position_embeddings')
     if args.vocab_file:
+        _compare('max_position_embeddings')
         _compare('make_vocab_size_divisible_by')
         _compare('padded_vocab_size')
         _compare('tokenizer_type')
diff --git a/megatron/indexer.py b/megatron/indexer.py
index c0d1ca7de..d2ff9e36f 100644
--- a/megatron/indexer.py
+++ b/megatron/indexer.py
@@ -1,15 +1,16 @@
 import sys
+import time
 import torch
 import torch.distributed as dist
 
-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import mpu
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
 from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
 from megatron.data.realm_index import detach, OpenRetreivalDataStore
-from megatron.model.biencoder_model import biencoder_model_provider
+from megatron.model.biencoder_model import get_model_provider
 from megatron.training import get_model
 
 
@@ -29,7 +30,6 @@ def __init__(self):
         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
         assert not (args.load and args.ict_load)
-        #self.using_realm_chkpt = args.ict_load is None
 
         self.log_interval = args.indexer_log_interval
         self.batch_size = args.indexer_batch_size
@@ -47,8 +47,8 @@ def load_attributes(self):
         if self.biencoder_shared_query_context_model:
             only_context_model = False
 
-        model = get_model(lambda: biencoder_model_provider(only_context_model \
-            = only_context_model, biencoder_shared_query_context_model = \
+        model = get_model(get_model_provider(only_context_model=\
+            only_context_model, biencoder_shared_query_context_model=\
             self.biencoder_shared_query_context_model))
 
         self.model = load_biencoder_checkpoint(model,
@@ -85,6 +85,7 @@ def build_and_save_index(self):
         """
         assert len(self.model) == 1
         unwrapped_model = self.model[0]
+
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
 
@@ -103,6 +104,7 @@ def build_and_save_index(self):
             context_logits = unwrapped_model.embed_text(
                 unwrapped_model.context_model, context_tokens, context_mask,
                 context_types)
+
             context_logits = detach(context_logits)
             row_id = detach(row_id)
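Replacing the `lambda` with `get_model_provider` means `get_model` now receives a provider with the standard `(pre_process, post_process)` signature used elsewhere in Megatron. A runnable toy version of that closure pattern (simplified names, not the real Megatron objects):

```python
def get_model_provider(only_context_model=False, shared_query_context=False):
    # Bind the biencoder configuration now; the trainer supplies the
    # pipeline-stage flags later, one call per model chunk.
    def model_provider(pre_process=True, post_process=True):
        return {"only_context_model": only_context_model,
                "shared_query_context": shared_query_context,
                "pre_process": pre_process,
                "post_process": post_process}
    return model_provider

provider = get_model_provider(only_context_model=True)
print(provider(pre_process=True, post_process=False))
```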
 
diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py
index d200bdb17..c53af8d54 100644
--- a/megatron/learning_rates.py
+++ b/megatron/learning_rates.py
@@ -87,7 +87,7 @@ def get_lr(self):
         else:
             raise Exception('{} decay style is not supported.'.format(
                 self.decay_style))
-       
+
         return self.min_lr + coeff * delta_lr
 
 
diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py
index 51ac0a060..e1f94bf1c 100644
--- a/megatron/model/biencoder_model.py
+++ b/megatron/model/biencoder_model.py
@@ -15,11 +15,30 @@
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
 
+def get_model_provider(only_query_model=False, only_context_model=False,
+        biencoder_shared_query_context_model=False):
+
+    def model_provider(pre_process=True, post_process=True):
+        """Build the model."""
+
+        print_rank_0('building Biencoder model ...')
+        model = biencoder_model_provider(only_query_model=only_query_model,
+                only_context_model=only_context_model,
+                biencoder_shared_query_context_model=\
+                biencoder_shared_query_context_model,
+                pre_process=pre_process, post_process=post_process)
+
+        return model
+
+    return model_provider
+
+
 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
-                             biencoder_shared_query_context_model=False):
+                             biencoder_shared_query_context_model=False,
+                             pre_process=True,
+                             post_process=True):
     """Build the model."""
-    args = get_args()
 
     assert mpu.get_tensor_model_parallel_world_size() == 1 and \
         mpu.get_pipeline_model_parallel_world_size() == 1, \
@@ -35,7 +54,9 @@ def biencoder_model_provider(only_query_model=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
         biencoder_shared_query_context_model=\
-            biencoder_shared_query_context_model)
+        biencoder_shared_query_context_model,
+        pre_process=pre_process,
+        post_process=post_process)
 
     return model
 
@@ -48,13 +69,17 @@ def __init__(self,
                  parallel_output=True,
                  only_query_model=False,
                  only_context_model=False,
-                 biencoder_shared_query_context_model=False):
+                 biencoder_shared_query_context_model=False,
+                 pre_process=True,
+                 post_process=True):
         super(BiEncoderModel, self).__init__()
         args = get_args()
 
         bert_kwargs = dict(
             num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
+            parallel_output=parallel_output,
+            pre_process=pre_process,
+            post_process=post_process)
 
         self.biencoder_shared_query_context_model = \
             biencoder_shared_query_context_model
@@ -78,6 +103,13 @@ def __init__(self,
                 self.context_model = PretrainedBertModel(**bert_kwargs)
                 self._context_key = 'context_model'
 
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        # This is just a placeholder; it will only be needed once model
+        # parallelism is used.
+        # self.language_model.set_input_tensor(input_tensor)
+        return
+
     def forward(self, query_tokens, query_attention_mask, query_types,
                 context_tokens, context_attention_mask, context_types):
         """Run a forward pass for each of the models and
@@ -217,7 +249,7 @@ class PretrainedBertModel(MegatronModule):
     learned information retrieval."""
 
     def __init__(self, num_tokentypes=2,
-            parallel_output=True):
+            parallel_output=True, pre_process=True, post_process=True):
         super(PretrainedBertModel, self).__init__()
 
         args = get_args()
@@ -225,6 +257,8 @@ def __init__(self, num_tokentypes=2,
         self.pad_id = tokenizer.pad
         self.biencoder_projection_dim = args.biencoder_projection_dim
         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(
             args.init_method_std, args.num_layers)
@@ -234,7 +268,9 @@ def __init__(self, num_tokentypes=2,
             add_pooler=False,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)
 
         if args.biencoder_projection_dim > 0:
             self.projection_enc = get_linear_layer(args.hidden_size,
@@ -247,7 +283,6 @@ def forward(self, input_ids, attention_mask, tokentype_ids=None):
         #extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)
 
-
         lm_output = self.language_model(input_ids,
                                         position_ids,
                                         extended_attention_mask,
@@ -285,7 +320,7 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
 
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
-        print_rank_0("loading BERT weights")
+        print_rank_0("loading pretrained weights")
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
 
diff --git a/megatron/model/vit_model.py b/megatron/model/vit_model.py
index 84a52a829..a1a86cfff 100644
--- a/megatron/model/vit_model.py
+++ b/megatron/model/vit_model.py
@@ -50,11 +50,11 @@ def __init__(self, hidden_size, num_classes):
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.
-        x = hidden_states[:, sequence_index, :]
-        x = self.dense_in(x)
-        x = torch.tanh(x)
-        x = self.dense_out(x)
-        return x
+        hidden_state = hidden_states[:, sequence_index, :]
+        dense_in_result = self.dense_in(hidden_state)
+        tanh_result = torch.tanh(dense_in_result)
+        dense_out_result = self.dense_out(tanh_result)
+        return dense_out_result
 
 
 def twod_interpolate_position_embeddings_hook(
@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
 class VitModel(MegatronModule):
     """Vision Transformer Model."""
 
-    def __init__(self, num_classes, finetune=False):
-        super(VitModel, self).__init__()
+    def __init__(self, 
+                 num_classes,
+                 finetune=False,
+                 pre_process=True,
+                 post_process=True):
+        super(VitModel, self).__init__(share_word_embeddings=False)
         args = get_args()
 
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
@@ -136,6 +140,8 @@ def __init__(self, num_classes, finetune=False):
                 args.init_method_std, args.num_layers
             )
 
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_classes = num_classes
         self.patch_dim = args.patch_dim
@@ -148,63 +154,81 @@ def __init__(self, num_classes, finetune=False):
         self.seq_length = self.num_patches + 1
         self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
 
-        # cls_token
-        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
-        torch.nn.init.zeros_(self.cls_token)
+        if self.pre_process:
+            # cls_token
+            self.cls_token = torch.nn.Parameter(
+                torch.randn(1, 1, self.hidden_size)
+            )
+            torch.nn.init.zeros_(self.cls_token)
 
-        # Linear encoder
-        self.linear_encoder = torch.nn.Linear(
-            self.flatten_dim, self.hidden_size
-        )
+            # Linear encoder
+            self.linear_encoder = torch.nn.Linear(
+                self.flatten_dim, self.hidden_size
+            )
 
-        # embedding
-        self.position_embeddings = torch.nn.Embedding(
-            self.seq_length, self.hidden_size
-        )
-        init_method_normal(args.init_method_std)(
-            self.position_embeddings.weight
-        )
-        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
+            # embedding
+            self.position_embeddings = torch.nn.Embedding(
+                self.seq_length, self.hidden_size
+            )
+            init_method_normal(args.init_method_std)(
+                self.position_embeddings.weight
+            )
+            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
 
-        self.position_embeddings._register_load_state_dict_pre_hook(
-            twod_interpolate_position_embeddings_hook
-        )
+            self.position_embeddings._register_load_state_dict_pre_hook(
+                twod_interpolate_position_embeddings_hook
+            )
 
-        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
+            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
 
         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method, self.scaled_init_method
+            self.init_method, 
+            self.scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process
         )
 
-        # MLP head
-        if not self.finetune:
-            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
-        else:
-            self.class_head = get_linear_layer(
-                self.hidden_size, num_classes, torch.nn.init.zeros_
+        if self.post_process:
+            # MLP head
+            if not self.finetune:
+                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
+            else:
+                self.class_head = get_linear_layer(
+                    self.hidden_size, num_classes, torch.nn.init.zeros_
+                )
+
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.transformer.set_input_tensor(input_tensor)
+
+    def forward(self, input):
+
+        if self.pre_process:
+            rearranged_input = einops.rearrange(
+                input,
+                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+                p1=self.patch_dim,
+                p2=self.patch_dim,
             )
 
-    def forward(self, x):
-        x = einops.rearrange(
-            x,
-            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-            p1=self.patch_dim,
-            p2=self.patch_dim,
-        )
+            assert rearranged_input.dtype == torch.half
+            encoder_output = self.linear_encoder(rearranged_input)
+            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
+            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
 
-        assert x.dtype == torch.half
-        x = self.linear_encoder(x)
-        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+            token_embeddings = concatenated_tokens + \
+                self.position_embeddings(self.position_ids)
+            hidden_states = self.embedding_dropout(token_embeddings)
+        else:
+            hidden_states = input
 
-        x = x + self.position_embeddings(self.position_ids)
-        x = self.embedding_dropout(x)
-        x = self.transformer(x, None)
+        hidden_states = self.transformer(hidden_states, None)
 
-        if not self.finetune:
-            x = self.mlp_head(x)
-        else:
-            x = self.class_head(x[:, 0, :])
+        if self.post_process:
+            if not self.finetune:
+                hidden_states = self.mlp_head(hidden_states)
+            else:
+                hidden_states = self.class_head(hidden_states[:, 0, :])
 
-        return x
+        return hidden_states
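The `pre_process`/`post_process` flags threaded through `VitModel` follow Megatron's pipeline-parallel convention: only the first stage owns the embeddings and only the last stage owns the head. A minimal sketch of how such flags are typically derived from the pipeline rank (hypothetical helper, not the actual Megatron call site):

```python
def stage_flags(pipeline_rank, num_stages):
    """First stage embeds the input; last stage runs the classification head."""
    pre_process = (pipeline_rank == 0)
    post_process = (pipeline_rank == num_stages - 1)
    return pre_process, post_process

# With 4 pipeline stages, only stage 0 embeds and only stage 3 classifies.
assert stage_flags(0, 4) == (True, False)
assert stage_flags(3, 4) == (False, True)
```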
diff --git a/megatron/tokenizer/bert_tokenization.py b/megatron/tokenizer/bert_tokenization.py
index a3aa6d907..99f9a8795 100644
--- a/megatron/tokenizer/bert_tokenization.py
+++ b/megatron/tokenizer/bert_tokenization.py
@@ -181,6 +181,35 @@ def convert_tokens_to_ids(self, tokens):
     def convert_ids_to_tokens(self, ids):
         return convert_by_vocab(self.inv_vocab, ids)
 
+    @staticmethod
+    def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
+        """ Converts a sequence of tokens (string) in a single string. """
+
+        def clean_up_tokenization(out_string):
+            """ Clean up a list of simple English tokenization artifacts
+            like spaces before punctuations and abreviated forms.
+            """
+            out_string = (
+                out_string.replace(" .", ".")
+                    .replace(" ?", "?")
+                    .replace(" !", "!")
+                    .replace(" ,", ",")
+                    .replace(" ' ", "'")
+                    .replace(" n't", "n't")
+                    .replace(" 'm", "'m")
+                    .replace(" 's", "'s")
+                    .replace(" 've", "'ve")
+                    .replace(" 're", "'re")
+            )
+            return out_string
+
+        text = ' '.join(tokens).replace(' ##', '').strip()
+        if clean_up_tokenization_spaces:
+            clean_text = clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
+
     def vocab_size(self):
         return len(self.vocab)
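A quick illustration of the detokenization helper added above, assuming it lands on `FullTokenizer` as the surrounding hunk suggests (the token sequence below is made up):

```python
from megatron.tokenizer.bert_tokenization import FullTokenizer

# WordPiece pieces for "the unbelievable answer wasn't here."
tokens = ["the", "un", "##believ", "##able", "answer", "was", "n't", "here", "."]

# Static method, so no vocab file is needed just to detokenize.
print(FullTokenizer.convert_tokens_to_string(tokens))
# -> the unbelievable answer wasn't here.
```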
 
diff --git a/megatron/training.py b/megatron/training.py
index 5adef3c11..0cdb0d1bc 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -110,6 +110,17 @@ def pretrain(train_valid_test_dataset_provider,
     args = get_args()
     timers = get_timers()
 
+    # XXX: quick hack-in for now - add a clean wrapper later
+    if args.codecarbon_dir is not None:
+        import codecarbon
+        from pathlib import Path
+        print("CC START")
+
+        Path(args.codecarbon_dir).mkdir(parents=True, exist_ok=True)
+        output_file = f"emissions-{args.rank:03d}.csv"
+        cc_tracker = codecarbon.EmissionsTracker(output_dir=args.codecarbon_dir, output_file=output_file)
+        cc_tracker.start()
+
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -162,6 +173,12 @@ def pretrain(train_valid_test_dataset_provider,
                                    test_data_iterator, model,
                                    0, True)
 
+
+    # XXX: clean up
+    if args.codecarbon_dir is not None:
+        print("CC STOP")
+        cc_tracker.stop()
+
 def update_train_iters(args):
 
     # For iteration-based training, we don't need to do anything
@@ -838,7 +855,10 @@ def build_train_valid_test_data_iterators(
         assert args.train_samples is None, \
             'only backward compatiblity support for iteration-based training'
         args.consumed_train_samples = args.iteration * args.global_batch_size
-    if args.iteration > 0 and args.consumed_valid_samples == 0:
+    # It's possible that training ran but evaluation never did, in which case
+    # args.consumed_valid_samples == 0 is valid.
+    # TODO: eval_interval could have changed between runs, so this might still be wrong.
+    if args.iteration // args.eval_interval > 0:
         assert args.train_samples is None, \
             'only backward compatiblity support for iteration-based training'
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
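To make the resume bookkeeping concrete: the hunk above counts one evaluation per `eval_interval` training iterations, and each evaluation consumes `eval_iters` batches. A worked example with assumed values (the assignment is truncated in this excerpt, so the full right-hand side is inferred here, not quoted):

```python
# Assumed resume state, purely illustrative.
iteration = 5000
eval_interval = 1000       # evaluate every 1000 training iterations
eval_iters = 10            # each evaluation runs 10 iterations
global_batch_size = 128

evaluations_so_far = iteration // eval_interval            # 5
consumed_valid_samples = evaluations_so_far * eval_iters * global_batch_size
print(consumed_valid_samples)                              # 6400
```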
diff --git a/pretrain_ict.py b/pretrain_ict.py
index 1438b3d57..79759250f 100644
--- a/pretrain_ict.py
+++ b/pretrain_ict.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 """Pretrain BERT for Inverse Cloze Task"""
+
+from functools import partial
 import math
 
 import torch
@@ -31,13 +33,16 @@
 from megatron.utils import average_losses_across_data_parallel_group
 
 
-def pretrain_ict_model_provider():
+def pretrain_ict_model_provider(pre_process=True, post_process=True):
     args = get_args()
+
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
                 biencoder_shared_query_context_model=\
-                    args.biencoder_shared_query_context_model)
+                args.biencoder_shared_query_context_model,
+                pre_process=pre_process, post_process=post_process)
+
     return model
 
 def get_group_world_size_rank():
@@ -77,25 +82,9 @@ def backward(ctx, grad_output):
         output = output_list[rank].contiguous()
         return output
 
-def forward_step(data_iterator, model, input_tensor):
-    """Forward step."""
+def loss_func(output_tensor):
     args = get_args()
-    timers = get_timers()
-
-    # Get the batch.
-    timers('batch-generator').start()
-    query_tokens, query_mask, \
-    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
-    timers('batch-generator').stop()
-
-    # Query and Context Types
-    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
-    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
-
-    # Forward model.
-    query_logits, context_logits = model(query_tokens, query_mask,
-                                    query_types, context_tokens,
-                                    context_mask, context_types)
+    query_logits, context_logits = output_tensor
 
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
@@ -137,6 +126,28 @@ def topk_accuracy(k):
     return loss, stats_dict
 
 
+
+def forward_step(data_iterator, model):
+    """Forward step."""
+    args = get_args()
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator').start()
+    query_tokens, query_mask, \
+    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
+    timers('batch-generator').stop()
+
+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
+
+    # Forward model.
+    output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
+                        context_mask, context_types)
+
+    return output_tensor, partial(loss_func)
+
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid and test datasets."""
     args = get_args()
diff --git a/pretrain_vit.py b/pretrain_vit.py
index 16ec10439..7770c68d5 100644
--- a/pretrain_vit.py
+++ b/pretrain_vit.py
@@ -17,19 +17,22 @@
 
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vit_model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 
-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
     print_rank_0("building VIT model ...")
     args = get_args()
 
-    model = VitModel(num_classes=args.num_classes)
+    model = VitModel(num_classes=args.num_classes,
+                     pre_process=pre_process,
+                     post_process=post_process)
     return model
 
 def get_batch(data_iterator):
@@ -42,10 +45,21 @@ def get_batch(data_iterator):
 
     return images, labels
 
-def forward_step(data_iterator, model, input_tensor):
+def loss_func(labels, output_tensor):
+    logits = output_tensor.contiguous().float()
+    loss = F.cross_entropy(logits, labels)
+
+    outputs = torch.argmax(logits, -1)
+    correct = (outputs == labels).float()
+    accuracy = torch.mean(correct)
+
+    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
+
+    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+
+def forward_step(data_iterator, model):
     """Forward step."""
     timers = get_timers()
-    assert input_tensor is None
 
     # Get the batch.
     timers("batch-generator").start()
@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
     timers("batch-generator").stop()
 
     # Forward model. lm_labels
-    logits = model(images).contiguous().float()
-    loss = F.cross_entropy(logits, labels)
-
-    outputs = torch.argmax(logits, -1)
-    correct = (outputs == labels).float()
-    accuracy = torch.mean(correct)
-
-    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
-
-    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+    output_tensor = model(images)
 
+    return output_tensor, partial(loss_func, labels)
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid, and test datasets."""
diff --git a/requirements.txt b/requirements.txt
index 1f7389c3e..cd79071d0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,11 @@
+# megatron requirements
+
 pybind11
 torch
 six
 regex
 numpy
+
+# big-science requirements
+
+codecarbon==1.2.0
diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index 918417b41..9411b1849 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -16,10 +16,10 @@
 """Finetune utilities."""
 
 from functools import partial
-
+import sys
 import torch
 
-from megatron import get_args
+from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
@@ -80,7 +80,8 @@ def _cross_entropy_forward_step(batch, model):
     return output_tensor, partial(cross_entropy_loss_func, labels)
 
 
-def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
+def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
+        task_collate_fn=None):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""
 
     # Sampler.
@@ -96,7 +97,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
                                               shuffle=False,
                                               num_workers=num_workers,
                                               drop_last=drop_last,
-                                              pin_memory=True)
+                                              pin_memory=True,
+                                              collate_fn=task_collate_fn)
 
     return data_loader
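The new `task_collate_fn` argument is passed straight through to `torch.utils.data.DataLoader(collate_fn=...)`. A toy collate function of the kind a task might supply (illustrative only; the ORQA finetuning task provides its own):

```python
import torch

def pad_collate(batch):
    """Pad variable-length 'ids' lists in a batch of dicts to the batch maximum."""
    max_len = max(len(sample["ids"]) for sample in batch)
    ids = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, sample in enumerate(batch):
        ids[i, :len(sample["ids"])] = torch.tensor(sample["ids"])
    return {"ids": ids}

# loader = build_data_loader(dataset, micro_batch_size=8, num_workers=2,
#                            drop_last=False, task_collate_fn=pad_collate)
```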
 
@@ -112,21 +114,24 @@ def _build_infinite_size_dataloader(dataloader):
             iterator = dataloader.__iter__()
 
 
-def _build_train_valid_dataloaders(train_dataset, valid_dataset):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset, 
+    task_collate_fn=None):
     """Traing and validation dataloaders."""
     args = get_args()
 
     print_rank_0('building train and validation dataloaders ...')
     # Training dataset.
     train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
-                                         args.num_workers, not args.keep_last)
+                                         args.num_workers, not args.keep_last,
+                                         task_collate_fn)
     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch
     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
     valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
-                                          args.num_workers, not args.keep_last)
+                                          args.num_workers, not args.keep_last,
+                                          task_collate_fn)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
 
     # Now that we've built the data loaders, set batch_size arguments
@@ -154,6 +159,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     args = get_args()
     timers = get_timers()
 
+    assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
+
     # Turn on training mode which enables dropout.
     for m in model:
         m.train()
@@ -188,6 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
 
             # Train for one step.
             out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
+
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
             iteration += 1
 
@@ -209,9 +217,11 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                                   optimizer, lr_scheduler)
 
             # Checkpointing
+            saved_checkpoint = False
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
                 save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                saved_checkpoint = True
 
             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -220,6 +230,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)
 
+            # Exiting based on iterations
+            if args.exit_interval and iteration % args.exit_interval == 0:
+                if not saved_checkpoint:
+                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                torch.distributed.barrier()
+                print_rank_0('exiting program at iteration {}'.format(iteration))
+                sys.exit()
+
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
@@ -231,7 +249,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
 
 def finetune(train_valid_datasets_provider, model_provider,
              forward_step=_cross_entropy_forward_step,
-             end_of_epoch_callback_provider=None):
+             end_of_epoch_callback_provider=None,
+             task_collate_fn=None):
     """Main finetune function used across all tasks."""
     args = get_args()
     timers = get_timers()
@@ -244,7 +263,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     if args.epochs > 0:
         train_dataset, valid_dataset = train_valid_datasets_provider()
         train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
-            train_dataset, valid_dataset)
+            train_dataset, valid_dataset, task_collate_fn)
     else:
         args.train_iters = 0
     timers('train/valid/test dataset/dataloder').stop()
@@ -268,8 +287,11 @@ def finetune(train_valid_datasets_provider, model_provider,
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
         original_load = args.load
         args.load = args.pretrained_checkpoint
+        original_rng = args.no_load_rng
+        args.no_load_rng = True
         _ = load_checkpoint(model, None, None)
         args.load = original_load
+        args.no_load_rng = original_rng
         # This is critical when only model is loaded. We should make sure
         # main parameters are also updated.
         optimizer.reload_model_params()
diff --git a/tasks/main.py b/tasks/main.py
index f5bd5ad69..6d8fc8f5f 100644
--- a/tasks/main.py
+++ b/tasks/main.py
@@ -62,6 +62,29 @@ def get_tasks_args(parser):
     group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                        help='Number of blocks to use as top-k during retrieval')
 
+    # finetune for retriever
+    group.add_argument('--eval-micro-batch-size', type=int, default=None,
+                       help='Eval Batch size per model instance (local batch '
+                            'size). Global batch size is local batch size '
+                            'times data parallel size.')
+    group.add_argument('--train-with-neg', action='store_true',
+                       help='Whether to use negative examples during model '
+                        'training')
+    group.add_argument('--train-hard-neg', type=int, default=0,
+                       help='Number of hard negative examples to use during '
+                        'training')
+
+
+    # parameters for Av.rank validation method
+    # Following options/arguments have been taken directly from DPR codebase
+    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
+                        help='Av.rank validation: how many hard negatives to'
+                        ' take from each question pool')
+    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
+                        help='Av.rank validation: how many other negatives to'
+                        ' take from each question pool')
+
+
     return parser
 
 
@@ -81,8 +104,10 @@ def get_tasks_args(parser):
         from glue.finetune import main
     elif args.task in ['LAMBADA', 'WIKITEXT103']:
         from zeroshot_gpt.evaluate import main
-    elif args.task in ['ICT-ZEROSHOT-NQ']:
+    elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
         from orqa.evaluate_orqa import main
+    elif args.task in ['RET-FINETUNE-NQ']:
+        from orqa.supervised.finetune import main
     else:
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))
diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
new file mode 100644
index 000000000..a8e8f8e6f
--- /dev/null
+++ b/tasks/orqa/README.md
@@ -0,0 +1,36 @@
+## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
+
+Below we present the steps to run unsupervised and supervised training and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
+
+## Retriever Training
+
+#### Unsupervised pretraining
+1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
+
+
+python tools/preprocess_data.py \
+    --input /path/to/corpus.json \
+    --json-keys text title \
+    --split-sentences \
+    --tokenizer-type BertWordPieceLowerCase \
+    --vocab-file /path/to/vocab.txt \
+    --output-prefix corpus_indexed \
+    --workers 10
+
+ +2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model and we use a total of batch size of 4096 for the ICT training. + +3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf). + +#### Supervised finetuning + +1. Use the above pretrained ICT model to finetune using [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top [DPR training](https://arxiv.org/abs/2004.04906). + +2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model. + +More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408). + +## Reader Training + +The reader component will be available soon. + diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py index 7e6b26923..87c59ea30 100644 --- a/tasks/orqa/evaluate_orqa.py +++ b/tasks/orqa/evaluate_orqa.py @@ -15,10 +15,8 @@ """Main tasks functionality.""" -import os -import sys - -from megatron import get_args +from megatron import get_args, print_rank_0 +from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator def main(): @@ -28,6 +26,20 @@ def main(): args = get_args() + """ + Create a BlockData data structure by running an IndexBuilder over an + ICT Dataset and then evaluate on NQ task + """ + + print_rank_0("Starting index builder!") + + index_builder = IndexBuilder() + index_builder.build_and_save_index() + print_rank_0("Build and save indices: done!") + + + print_rank_0("Starting evaluations!") + # Set up the model and evaluator evaluator = ORQAEvaluator() diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py index ebee03522..08b1e929b 100644 --- a/tasks/orqa/evaluate_utils.py +++ b/tasks/orqa/evaluate_utils.py @@ -18,13 +18,14 @@ from megatron import get_args, print_rank_0 from megatron.checkpointing import load_biencoder_checkpoint from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset -from tasks.orqa.natural_questions.nq import get_nq_dataset -from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader -from tasks.orqa.natural_questions.nq import process_nq_batch -from tasks.orqa.natural_questions.qa_utils import calculate_matches from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex -from megatron.model.biencoder_model import biencoder_model_provider +from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model +from tasks.orqa.unsupervised.nq import get_nq_dataset +from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader +from tasks.orqa.unsupervised.nq import process_nq_batch +from tasks.orqa.unsupervised.qa_utils import calculate_matches + class ORQAEvaluator(object): def __init__(self): @@ -44,9 +45,8 @@ def __init__(self): if args.biencoder_shared_query_context_model: only_query_model = False - 
model = get_model(lambda: biencoder_model_provider(only_query_model=\ - only_query_model, biencoder_shared_query_context_model=\ - args.biencoder_shared_query_context_model)) + model = get_model(get_model_provider(only_query_model=only_query_model, + biencoder_shared_query_context_model=args.biencoder_shared_query_context_model)) self.model = load_biencoder_checkpoint(model, only_query_model=only_query_model) diff --git a/tasks/orqa/supervised/data.py b/tasks/orqa/supervised/data.py new file mode 100644 index 000000000..b45a842b6 --- /dev/null +++ b/tasks/orqa/supervised/data.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ORQA dataset.""" + +import json +import random +from abc import ABC +from abc import abstractmethod + +import numpy as np +from torch.utils.data import Dataset + +from megatron import print_rank_0, get_args +from megatron.data.biencoder_dataset_utils import make_attention_mask + +def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length): + ctx_id_list, ctx_types_list = [], [] + for context in ctx_list: + title_ids = tokenizer.tokenize(context['title']) + ctx_ids = tokenizer.tokenize(context['text']) + ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids + + ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids, + max_seq_length, tokenizer.cls, + tokenizer.sep, tokenizer.pad) + ctx_id_list.append(ctx_ids) + ctx_types_list.append(ctx_types) + + return ctx_id_list, ctx_types_list + + +def build_tokens_types_paddings_from_text(query, context, + tokenizer, max_seq_length): + """Build token types and paddings, trim if needed, and pad if needed.""" + + query_ids = tokenizer.tokenize(query) + query_ids, query_types, query_pad_mask = \ + build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \ + tokenizer.cls, tokenizer.sep, tokenizer.pad) + + # Appending the title of the context at front + extended_ctx_ids = None + if context is not None: + title_ids = tokenizer.tokenize(context['title']) + ctx_ids = tokenizer.tokenize(context['text']) + extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids + + ctx_ids, ctx_types, ctx_pad_mask = \ + build_tokens_types_paddings_from_ids(extended_ctx_ids, + max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad) + + return query_ids, query_types, query_pad_mask, \ + ctx_ids, ctx_types, ctx_pad_mask + + +# Similar code tasks/data_utils with some changes +def build_tokens_types_paddings_from_ids(text_ids, max_seq_length, + cls_id, sep_id, pad_id): + """Build token types and paddings, trim if needed, and pad if needed.""" + enc_ids = [] + tokentypes_enc = [] + + # [CLS]. + enc_ids.append(cls_id) + tokentypes_enc.append(0) + + # A. + len_src = len(text_ids) + enc_ids.extend(text_ids) + tokentypes_enc.extend([0] * len_src) + + # Cap the size. + if len(enc_ids) > max_seq_length - 1: + enc_ids = enc_ids[0: max_seq_length - 1] + tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] + + # [SEP]. 
+ enc_ids.append(sep_id) + tokentypes_enc.append(0) + + num_tokens_enc = len(enc_ids) + # Padding. + padding_length = max_seq_length - len(enc_ids) + if padding_length > 0: + enc_ids.extend([pad_id] * padding_length) + tokentypes_enc.extend([pad_id] * padding_length) + + pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length) + pad_mask = np.array(pad_mask, dtype=np.int64) + + return enc_ids, tokentypes_enc, pad_mask + + +def build_sample(query_ids, query_types, query_pad_mask, + ctx_ids, ctx_types, ctx_pad_mask, answers, + neg_ctx_id_list=None, neg_ctx_types_list=None, + include_neg=False): + """Convert to numpy and return a sample consumed by the batch producer.""" + + query_ids = np.array(query_ids, dtype=np.int64) + query_types = np.array(query_types, dtype=np.int64) + query_mask = make_attention_mask(query_ids, query_ids) + + ctx_ids = np.array(ctx_ids, dtype=np.int64) + ctx_types = np.array(ctx_types, dtype=np.int64) + ctx_mask = make_attention_mask(ctx_ids, ctx_ids) + + sample = ({ + 'query': query_ids, + 'query_mask': query_mask, + 'query_types': query_types, + 'query_pad_mask': query_pad_mask, + 'context': ctx_ids, + 'context_mask': ctx_mask, + 'context_types': ctx_types, + 'context_pad_mask': ctx_pad_mask, + 'reference': answers + }) + + if include_neg: + neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64) + neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64) + neg_ctx_mask = np.array([make_attention_mask(ids, ids) \ + for ids in neg_ctx_ids], dtype=np.int64) + + sample['neg_context'] = neg_ctx_ids + sample['neg_context_types'] = neg_ctx_id_types + sample['neg_context_mask'] = neg_ctx_mask + + return sample + + +class OpenRetrievalAbstractDataset(ABC, Dataset): + """Open Retrieval base dataset class.""" + + def __init__(self, task_name, dataset_name, datapaths, tokenizer, \ + max_seq_length, evaluate=False): + # Store inputs. + args = get_args() + self.evaluate = evaluate + self.val_av_rank_hard_neg = args.val_av_rank_hard_neg + self.val_av_rank_other_neg = args.val_av_rank_other_neg + self.train_with_neg = args.train_with_neg + self.train_hard_neg = args.train_hard_neg + + self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + # Process the files. 
+ string = ' > paths:' + for path in datapaths: + string += ' ' + path + print_rank_0(string) + self.samples = [] + for datapath in datapaths: + self.samples.extend(self.process_samples_from_single_path(datapath)) + + args = get_args() + if args.sample_rate < 1: # subsample + k = int(len(self.samples) * args.sample_rate) + self.samples = random.sample(self.samples, k) + + print_rank_0(' >> total number of samples: {}'.format( + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + + query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \ + ctx_pad_mask = build_tokens_types_paddings_from_text( \ + raw_sample['question'], raw_sample['pos_context'], \ + self.tokenizer, self.max_seq_length) + + if self.evaluate: + neg_ctx_list = \ + raw_sample['negative_context'][:self.val_av_rank_other_neg] + \ + raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg] + neg_ctx_id_list, neg_ctx_types_list = \ + build_token_types_from_context_list(neg_ctx_list, \ + self.tokenizer, self.max_seq_length) + + elif self.train_with_neg: + hard_negative_ctx = raw_sample['hard_negative_context'] + negative_ctx = raw_sample['negative_context'] + if True: # TODO: fix this or remove this condition + random.shuffle(hard_negative_ctx) + random.shuffle(negative_ctx) + + neg_ctx_list = hard_negative_ctx[:self.train_hard_neg] + # In the Google NQ dataset by DPR paper, there are around more than + # 50 missing hard negatives in training data. + # In those cases, substitute hard negatives by simple negatives. + if len(neg_ctx_list) < self.train_hard_neg: + neg_ctx_list += negative_ctx[:self.train_hard_neg - \ + len(neg_ctx_list)] + + neg_ctx_id_list, neg_ctx_types_list = \ + build_token_types_from_context_list(neg_ctx_list, + self.tokenizer, self.max_seq_length) + else: + neg_ctx_id_list = None + neg_ctx_types_list = None + + sample = build_sample(query_ids, query_types, query_pad_mask, + ctx_ids, ctx_types, ctx_pad_mask, + raw_sample['answers'], + neg_ctx_id_list, neg_ctx_types_list, + include_neg=self.evaluate or self.train_with_neg) + + return sample + + @staticmethod + @abstractmethod + def process_samples_from_single_path(filename): + """Abstract method that takes a filename and + returns a list of dataset samples, each sample being a dict of + {'text': string, 'text': string} + """ + pass + + + +def normalize_question(question): + if question[-1] == '?': + question = question[:-1] + return question + +# The following class reads the datasets for training retriever as +# prepared by the DPR codebase (https://github.com/facebookresearch/DPR) + +class NQSupervisedDataset(OpenRetrievalAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, \ + evaluate=False): + super().__init__('natural_questions_ret', + name, + datapaths, + tokenizer, + max_seq_length, + evaluate=evaluate) + + @staticmethod + def process_samples_from_single_path(filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + samples = [] + total = 0 + + with open(filename, 'r', encoding="utf-8") as f: + data = json.load(f) + for row in data: + question = normalize_question(row['question']) + pos_context = row['positive_ctxs'][0] + + # Hard Negative Contexts + if len(row['hard_negative_ctxs']) > 0: + hard_neg_context = row['hard_negative_ctxs'] + else: + hard_neg_context = [] + + # Negative Contexts + if len(row['negative_ctxs']) > 0: + neg_context = row['negative_ctxs'] + else: + neg_context = 
[] + + answers = row['answers'] + sample = {'question': question, + 'pos_context': pos_context, + 'hard_negative_context': hard_neg_context, + 'negative_context': neg_context, + 'answers': answers} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + return samples + diff --git a/tasks/orqa/supervised/eval_utils.py b/tasks/orqa/supervised/eval_utils.py new file mode 100644 index 000000000..67dca512b --- /dev/null +++ b/tasks/orqa/supervised/eval_utils.py @@ -0,0 +1,206 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluation utilities.""" +from collections import OrderedDict +import math +import numpy as np +import time +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from megatron import get_args, print_rank_0 +from megatron import mpu +from megatron.utils import average_losses_across_data_parallel_group +from tasks.finetune_utils import build_data_loader + +def task_collate_fn(batch_data): + # generate batch + batch_size = len(batch_data) + tensorized = OrderedDict() + for d in batch_data: + for k, v in d.items(): + tensorized.setdefault(k, []).append(v) + + tensorized['query'] = torch.LongTensor(tensorized['query']) + tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask']) + tensorized['query_types'] = torch.LongTensor(tensorized['query_types']) + tensorized['query_pad_mask'] = \ + torch.LongTensor(tensorized['query_pad_mask']) + + tensorized['context'] = torch.LongTensor(tensorized['context']) + tensorized['context_mask'] = \ + torch.LongTensor(tensorized['context_mask']) + tensorized['context_types'] = \ + torch.LongTensor(tensorized['context_types']) + tensorized['context_pad_mask'] = \ + torch.LongTensor(tensorized['context_pad_mask']) + + if 'neg_context' in tensorized: + tensorized['neg_context'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context'])) + tensorized['neg_context_mask'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context_mask'])) + tensorized['neg_context_types'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context_types'])) + + return tensorized + + + +def process_batch(batch): + """Process batch and produce inputs for the model.""" + query_tokens = batch['query'].long().cuda() + query_mask = (batch['query_mask'] < 0.5).cuda() + query_types = batch['query_types'].long().cuda() + query_pad_mask = batch['query_pad_mask'].long().cuda() + + context_tokens = batch['context'].long().cuda() + context_mask = (batch['context_mask'] < 0.5).cuda() + context_types = batch['context_types'].long().cuda() + context_pad_mask = batch['context_pad_mask'].long().cuda() + + if 'neg_context' in batch: + neg_context_tokens = batch['neg_context'].long().cuda() + neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda() + neg_context_types = batch['neg_context_types'].long().cuda() 
+    else:
+        neg_context_tokens = None
+        neg_context_mask = None
+        neg_context_types = None
+
+    reference = batch['reference']
+
+    return query_tokens, query_mask, query_types, query_pad_mask, \
+           context_tokens, context_mask, context_types, context_pad_mask, \
+           neg_context_tokens, neg_context_mask, neg_context_types, reference
+
+def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
+    """Provide function that calculates accuracies."""
+    args = get_args()
+
+    print_rank_0("accuracy_func_provider is CALLED")
+
+    # Build dataloaders
+    datapath = args.valid_data
+    dataset = single_dataset_provider(datapath)
+
+    drop_last = False
+    if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
+        drop_last = True
+
+    print_rank_0(datapath)
+    print_rank_0(rank0sampler)
+
+    dataloader = build_data_loader(dataset,
+                                   args.eval_micro_batch_size,
+                                   num_workers=args.num_workers,
+                                   drop_last=drop_last,
+                                   task_collate_fn=task_collate_fn)
+    dataloaders = (dataset.dataset_name, dataloader)
+
+    def metrics_func(model, epoch, output_predictions=False):
+        print_rank_0('calculating metrics by accuracy func in ORQA...')
+
+        if output_predictions:
+            assert rank0sampler
+            names = 'predictions'
+        name, dataloader = dataloaders
+        if args.task == "RET-FINETUNE-NQ":
+            start_time = time.time()
+            output = retrieval_loss(model, dataloader)
+            stats_dict, total = output
+            format_string = ""
+            for k, v in stats_dict.items():
+                format_string += "|{} = {:.2f}".format(k, v / total)
+            print_rank_0("epoch:{}{}".format(epoch, format_string))
+            print_rank_0("time taken to calculate metrics {:.3f}".format(\
+                time.time() - start_time))
+        else:
+            raise AssertionError("{} Task not supported".format(args.task))
+
+    return metrics_func
+
+
+def retrieval_loss(model, dataloader):
+    args = get_args()
+    total = 0
+    topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
+        args.retriever_report_topk_accuracies}
+    stats_dict = dict(rank=0, **topk_stats_dict)
+
+    assert len(model) == 1
+    unwrapped_model = model[0]
+    unwrapped_model.eval()
+
+    with torch.no_grad():
+        # For all the batches in the dataset.
+        for batch in dataloader:
+            # Run the model forward.
+ query_tokens, query_mask, query_types, _, \ + context_tokens, context_mask, context_types, _, \ + neg_context_tokens, neg_context_mask, neg_context_types, \ + reference = process_batch(batch) + + query_logits, context_logits = unwrapped_model(query_tokens, + query_mask, query_types, + torch.cat([context_tokens, neg_context_tokens]), + torch.cat([context_mask, neg_context_mask]), + torch.cat([context_types, neg_context_types])) + + retrieval_scores = torch.matmul(query_logits, + torch.transpose(context_logits, 0, 1)) + + if args.retriever_score_scaling: + retrieval_scores = retrieval_scores / \ + math.sqrt(args.hidden_size) + + local_batch_size = query_logits.shape[0] + labels = torch.arange(local_batch_size).long().cuda() + + softmax_scores = F.softmax(retrieval_scores, dim=1) + sorted_vals, sorted_indices = torch.topk(softmax_scores, + k=softmax_scores.shape[1], + sorted=True) + + def topk_accuracy(k): + return torch.cuda.FloatTensor( + [sum([int(labels[i] in sorted_indices[i, :k]) for i in \ + range(local_batch_size)])]) + + def get_rank(): + return torch.cuda.FloatTensor( + [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \ + for i in range(local_batch_size)])]) + + topk_accs = [topk_accuracy(k) for k in \ + args.retriever_report_topk_accuracies] + rank = get_rank() + losses = average_losses_across_data_parallel_group([rank, \ + *topk_accs]) + + # create stats_dict with retrieval loss and all specified + # top-k accuracies + topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \ + zip(args.retriever_report_topk_accuracies, losses[1:])} + temp_stats_dict = dict(rank=losses[0], **topk_acc_dict) + for k in stats_dict.keys(): + stats_dict[k] += temp_stats_dict[k] + total += local_batch_size + + unwrapped_model.train() + + return stats_dict, total diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py new file mode 100644 index 000000000..aed65ac97 --- /dev/null +++ b/tasks/orqa/supervised/finetune.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""ORQA finetuning/evaluation."""
+
+from functools import partial
+import sys
+
+import math
+import torch
+import torch.nn.functional as F
+
+from megatron import get_args, get_timers, get_tokenizer
+from megatron import mpu, print_rank_0
+from megatron.indexer import IndexBuilder
+from megatron.model.biencoder_model import biencoder_model_provider
+from megatron.utils import average_losses_across_data_parallel_group
+from pretrain_ict import get_group_world_size_rank
+from tasks.finetune_utils import finetune
+from tasks.orqa.supervised.eval_utils import accuracy_func_provider
+from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
+from tasks.orqa.evaluate_utils import ORQAEvaluator
+
+# input_ is a 2D tensor
+def check_and_append_tensor_for_gather(group, rank, world_size, input_):
+
+    # gather the size of the first dimension of the tensor from all ranks
+    current_length = input_.size()[0]
+    first_dim = torch.tensor([[current_length]],
+        device=torch.cuda.current_device())
+    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
+    input_list[rank].copy_(first_dim)
+    torch.distributed.all_gather(input_list, first_dim, group=group)
+    all_input_list = torch.cat(input_list, dim=0).contiguous()
+    max_length = torch.max(all_input_list)
+
+    # if the size differs from the max, extend the tensor
+    # accordingly
+    if max_length > current_length:
+        padding=tuple([0] * (input_.dim() * 2 - 1)) + \
+            tuple([max_length - current_length])
+        input_ = F.pad(input=input_, pad=padding)
+
+    return input_
+
+def orqa(Dataset):
+
+    def cross_entropy_forward_step(batch, model):
+        """Simple forward step with cross-entropy loss."""
+        timers = get_timers()
+        tokenizer = get_tokenizer()
+
+        # Get the batch.
+        timers('batch generator').start()
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+
+        group, rank, world_size = get_group_world_size_rank()
+
+        query_tokens, query_mask, query_types, query_pad_mask, \
+            context_tokens, context_mask, context_types, context_pad_mask, \
+            neg_context_tokens, neg_context_mask, neg_context_types, \
+            reference = process_batch(batch_)
+
+        timers('batch generator').stop()
+        local_batch_size = query_tokens.shape[0]
+
+        # Text representation of query and context
+        query_list, context_list = [], []
+        for i in range(local_batch_size):
+            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
+            context_list.append(tokenizer.decode(context_tokens[i].tolist()))
+
+        if neg_context_tokens is not None:
+            neg_context_tokens = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_tokens)
+            neg_context_mask = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_mask)
+            neg_context_types = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_types)
+
+        if neg_context_tokens is not None:
+            context_tokens = torch.cat([context_tokens, neg_context_tokens])
+            context_mask = torch.cat([context_mask, neg_context_mask])
+            context_types = torch.cat([context_types, neg_context_types])
+
+        # Forward model.
+ output_tensor = model(query_tokens, query_mask, + query_types, context_tokens, + context_mask, context_types) + return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens) + + + def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor): + args = get_args() + + local_batch_size = query_tokens.shape[0] + group, rank, world_size = get_group_world_size_rank() + # recall we assert that model_parallel_size == 1 + global_batch_size = world_size * local_batch_size + + query_logits, context_logits = output_tensor + + if world_size > 1: + input_ = torch.empty_like(context_logits).copy_(\ + context_logits).detach_() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank].copy_(input_) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Check if all-gather happens in order + assert tensor_list[rank].sum().item() == \ + context_logits.sum().item() + + # Preserves the gradient + tensor_list[rank] = context_logits + all_context_logits = torch.cat(tensor_list, dim=0).contiguous() + + # Query tensors + input_ = torch.empty_like(query_logits).copy_(\ + query_logits).detach_() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank].copy_(input_) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Check if all-gather happens in order + assert tensor_list[rank].sum().item() == query_logits.sum().item() + + # Preserves the gradient + tensor_list[rank] = query_logits + all_query_logits = torch.cat(tensor_list, dim=0).contiguous() + else: + all_query_logits = query_logits + all_context_logits = context_logits + + retrieval_scores = torch.matmul(all_query_logits, + torch.transpose(all_context_logits, 0, 1)) + # Scaling the retrieval scores + if args.retriever_score_scaling: + retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size) + + if args.train_with_neg: + # if the world size is 3, local batch size is 4, and + # local context size is 8, what we want is + # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19] + labels = [] + local_context_size = context_tokens.shape[0] + for i in range(world_size): + j = i * local_context_size + labels.extend(list(range(j, j + local_batch_size))) + labels = torch.LongTensor(labels).cuda() + assert len(labels) == global_batch_size + else: + labels = torch.arange(global_batch_size).long().cuda() + + # Cross-entropy loss. + softmax_scores = F.log_softmax(retrieval_scores, dim=1) + + loss = F.nll_loss(softmax_scores, labels, reduction='mean') + + max_score, max_idxs = torch.max(softmax_scores, 1) + correct_predictions_count = (max_idxs == labels).sum().float() + + # Reduce loss for logging. 
+ reduced_loss = average_losses_across_data_parallel_group([loss, \ + correct_predictions_count]) + + # Loss scaling for correct losses in Supervised Retrieval + loss = loss * mpu.get_data_parallel_world_size() + + return loss, {'lm loss': reduced_loss[0], + 'correct_prediction_count': reduced_loss[1]} + + + def train_valid_datasets_provider(): + """Build train and validation dataset.""" + args = get_args() + tokenizer = get_tokenizer() + + train_dataset = Dataset('training', + args.train_data, + tokenizer, + args.retriever_seq_length, + evaluate=False) + valid_dataset = Dataset('validation', + args.valid_data, + tokenizer, + args.retriever_seq_length, + evaluate=True) + return train_dataset, valid_dataset + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + print_rank_0('building retriever model for {} ...'.format(args.task)) + + model = biencoder_model_provider(only_context_model=False, + only_query_model=False, + biencoder_shared_query_context_model=\ + args.biencoder_shared_query_context_model, + pre_process=pre_process, post_process=post_process) + + return model + + def single_dataset_provider(datapath): + args = get_args() + tokenizer = get_tokenizer() + + name = datapath[0].split('/')[-1].split('.')[0] + return Dataset(name, + datapath, + tokenizer, + args.retriever_seq_length, + evaluate=True) + + def metrics_func_provider(): + """Provide metrics callback function.""" + return accuracy_func_provider(single_dataset_provider) + + """Finetune/evaluate.""" + finetune(train_valid_datasets_provider, + model_provider, + forward_step=cross_entropy_forward_step, + end_of_epoch_callback_provider=metrics_func_provider, + task_collate_fn=task_collate_fn) + +def main(): + args = get_args() + + if args.task == 'RET-FINETUNE-NQ': + from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset + else: + raise NotImplementedError('ORQA task {} is not implemented.'.format( + args.task)) + + orqa(Dataset) + diff --git a/tasks/orqa/natural_questions/nq.py b/tasks/orqa/unsupervised/nq.py similarity index 100% rename from tasks/orqa/natural_questions/nq.py rename to tasks/orqa/unsupervised/nq.py diff --git a/tasks/orqa/natural_questions/qa_utils.py b/tasks/orqa/unsupervised/qa_utils.py similarity index 98% rename from tasks/orqa/natural_questions/qa_utils.py rename to tasks/orqa/unsupervised/qa_utils.py index 24e71e683..811a05834 100644 --- a/tasks/orqa/natural_questions/qa_utils.py +++ b/tasks/orqa/unsupervised/qa_utils.py @@ -22,7 +22,7 @@ from typing import Tuple, List, Dict import regex as re -from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer +from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer logger = logging.getLogger(__name__) diff --git a/tasks/orqa/natural_questions/tokenizers.py b/tasks/orqa/unsupervised/tokenizers.py similarity index 100% rename from tasks/orqa/natural_questions/tokenizers.py rename to tasks/orqa/unsupervised/tokenizers.py diff --git a/tasks/vision/classification.py b/tasks/vision/classification.py index 5232b3f54..71e840757 100644 --- a/tasks/vision/classification.py +++ b/tasks/vision/classification.py @@ -34,13 +34,14 @@ def train_valid_datasets_provider(): ) return train_ds, valid_ds - def model_provider(): + def model_provider(pre_process=True, post_process=True): """Build the model.""" args = get_args() print_rank_0("building classification model for ImageNet ...") - return VitModel(num_classes=args.num_classes, finetune=True) + return VitModel(num_classes=args.num_classes, 
finetune=True, + pre_process=pre_process, post_process=post_process) """Finetune/evaluate.""" finetune( diff --git a/tasks/vision/eval_utils.py b/tasks/vision/eval_utils.py index aabc04a15..3a194119c 100644 --- a/tasks/vision/eval_utils.py +++ b/tasks/vision/eval_utils.py @@ -16,10 +16,14 @@ """Evaluation utilities.""" import os +from functools import partial + import torch + from megatron import get_args -from megatron import print_rank_0 +from megatron import print_rank_0, print_rank_last from megatron import mpu +from megatron.schedules import get_forward_backward_func from tasks.vision.finetune_utils import build_data_loader from tasks.vision.finetune_utils import process_batch from torchvision import datasets, transforms @@ -56,7 +60,7 @@ def metrics_func(model, epoch): print_rank_0("calculating metrics ...") correct, total = calculate_correct_answers(model, dataloader, epoch) percent = float(correct) * 100.0 / float(total) - print_rank_0( + print_rank_last( " >> |epoch: {}| overall: correct / total = {} / {} = " "{:.4f} %".format(epoch, correct, total, percent) ) @@ -67,29 +71,61 @@ def metrics_func(model, epoch): def calculate_correct_answers(model, dataloader, epoch): """Calculate correct over total answers""" - model.eval() + args = get_args() + forward_backward_func = get_forward_backward_func() + for m in model: + m.eval() + + def loss_func(labels, output_tensor): + logits = output_tensor + + loss_dict = {} + # Compute the correct answers. + predicted = torch.argmax(logits, dim=-1) + corrects = (predicted == labels).float() + # Add to the counters. + loss_dict['total'] = labels.size(0) + loss_dict['correct'] = corrects.sum().item() + + return 0, loss_dict + + #defined inside to capture output_predictions + def correct_answers_forward_step(batch, model): + try: + batch_ = next(batch) + except BaseException: + batch_ = batch + images, labels = process_batch(batch_) + + # Forward model. + args = get_args() + output_tensor = model(images) + + return output_tensor, partial(loss_func, labels) + with torch.no_grad(): # For all the batches in the dataset. total = 0 correct = 0 for _, batch in enumerate(dataloader): - # Run the model forward. - images, labels = process_batch(batch) - logits = model(images).contiguous().float() - # Add output predictions. - # Compute the correct answers. - predicted = torch.argmax(logits, dim=-1) - corrects = (predicted == labels).float() - # Add to the counters. - total += labels.size(0) - correct += corrects.sum().item() - model.train() + + loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model, + optimizer=None, timers=None, forward_only=True) + + for loss_dict in loss_dicts: + total += loss_dict['total'] + correct += loss_dict['correct'] + + for m in model: + m.train() # Reduce. - unreduced = torch.cuda.LongTensor([correct, total]) - torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group()) + if mpu.is_pipeline_last_stage(): + unreduced = torch.cuda.LongTensor([correct, total]) + torch.distributed.all_reduce(unreduced, + group=mpu.get_data_parallel_group()) - # Print on screen. - correct_ans = unreduced[0].item() - total_count = unreduced[1].item() - return correct_ans, total_count + # Print on screen. 
+ correct_ans = unreduced[0].item() + total_count = unreduced[1].item() + return correct_ans, total_count diff --git a/tasks/vision/finetune_utils.py b/tasks/vision/finetune_utils.py index afde4aa89..f9743883c 100644 --- a/tasks/vision/finetune_utils.py +++ b/tasks/vision/finetune_utils.py @@ -17,6 +17,7 @@ import torch import torch.nn.functional as F +from functools import partial from megatron import get_args from megatron import print_rank_0 from megatron import get_timers @@ -38,10 +39,21 @@ def process_batch(batch): return images, labels -def _cross_entropy_forward_step(batch, model, input_tensor): +def cross_entropy_loss_func(labels, output_tensor): + logits = output_tensor + + # Cross-entropy loss. + loss = F.cross_entropy(logits.contiguous().float(), labels) + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + +def _cross_entropy_forward_step(batch, model): """Simple forward step with cross-entropy loss.""" timers = get_timers() - assert input_tensor is None # Get the batch. timers("batch generator").start() @@ -52,16 +64,10 @@ def _cross_entropy_forward_step(batch, model, input_tensor): images, labels = process_batch(batch_) timers("batch generator").stop() - # Forward model. - logits = model(images).contiguous().float() - - # Cross-entropy loss. - loss = F.cross_entropy(logits, labels) - - # Reduce loss for logging. - average_loss = average_losses_across_data_parallel_group([loss]) - - return loss, {"lm loss": average_loss[0]} + # Forward model. + output_tensor = model(images) + + return output_tensor, partial(cross_entropy_loss_func, labels) def build_data_loader(dataset, micro_batch_size, num_workers, drop_last): @@ -103,23 +109,28 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset): """Traing and validation dataloaders.""" args = get_args() - print_rank_0("building train and validation dataloaders ...") + print_rank_0('building train and validation dataloaders ...') # Training dataset. - train_dataloader = build_data_loader( - train_dataset, args.micro_batch_size, args.num_workers, not args.keep_last - ) + train_dataloader = build_data_loader(train_dataset, args.micro_batch_size, + args.num_workers, not args.keep_last) # Set the training iterations. args.train_iters_per_epoch = len(train_dataloader) args.train_iters = args.epochs * args.train_iters_per_epoch # Validation dataset. For this dataset, we do not need to set up # shuffling so we can just use a simple infinite loop. - valid_dataloader_ = build_data_loader( - valid_dataset, args.micro_batch_size, args.num_workers, not args.keep_last - ) + valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size, + args.num_workers, not args.keep_last) valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_) - return train_dataloader, valid_dataloader + # Now that we've built the data loaders, set batch_size arguments + # to the actual batch size the model will see for this dataset. + # This is necessary so pipeline transfers know what size they are + # and the LR schedule, which is based on samples seen, gets set + # correctly. + args.orig_micro_batch_size = args.micro_batch_size + args.orig_global_batch_size = args.global_batch_size + return train_dataloader, valid_dataloader def _train( model, @@ -135,7 +146,8 @@ def _train( timers = get_timers() # Turn on training mode which enables dropout. - model.train() + for m in model: + m.train() # Tracking loss. 
losses_dict_sum = {} @@ -166,12 +178,16 @@ def _train( start_iteration = 0 # Train for one step. - losses_dict, skipped_iter = train_step( + losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step( forward_step, batch, model, optimizer, lr_scheduler ) iteration += 1 # Logging. + params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(model) + report_memory_flag = training_log( losses_dict, losses_dict_sum, @@ -180,6 +196,9 @@ def _train( optimizer.get_loss_scale().item(), report_memory_flag, skipped_iter, + grad_norm, + params_norm, + num_zeros_in_grad ) # Autoresume diff --git a/tools/create_doc_index.py b/tools/create_doc_index.py deleted file mode 100644 index 4448d0e29..000000000 --- a/tools/create_doc_index.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -import sys -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.pardir))) - -from megatron import print_rank_0 -from megatron.indexer import IndexBuilder -from megatron.initialize import initialize_megatron - - -def main(): - """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset - - Include all args needed for initial model specification - - Other key args: - --block-data-path: path to write to - --ict-load or --realm-load: path to checkpoint with which to embed - --data-path and --titles-data-path: paths for dataset - --indexer-log-interval: reporting interval - --indexer-batch-size: size specific for indexer jobs - - Check README.md for example script - """ - - initialize_megatron(extra_args_provider=None, - args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) - index_builder = IndexBuilder() - index_builder.build_and_save_index() - print_rank_0("Build and save indices: done!") - -if __name__ == "__main__": - main() -
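
The retrieval evaluation in `tasks/orqa/supervised/eval_utils.py` above scores every candidate context for each query, softmaxes the scores, and counts how often the gold context (the diagonal entry) appears among the top-k ranked candidates. The following is a minimal, standalone sketch of that bookkeeping on CPU; the random scores, `batch_size`, `num_contexts`, and the printed k values are illustrative only, not values used by the task.

```python
# Illustrative sketch of the top-k accuracy bookkeeping in retrieval_loss();
# the scores are random and the sizes are made up for this example.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_contexts = 4, 12           # assumed sizes for the sketch
scores = torch.randn(batch_size, num_contexts)
labels = torch.arange(batch_size)          # gold context for query i is row i

softmax_scores = F.softmax(scores, dim=1)
_, sorted_indices = torch.topk(softmax_scores, k=num_contexts, sorted=True)

for k in (1, 5):
    correct = sum(int(labels[i] in sorted_indices[i, :k])
                  for i in range(batch_size))
    print('top-{} accuracy: {:.1f}%'.format(k, 100.0 * correct / batch_size))
```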
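In `cross_entropy_loss_func` in `tasks/orqa/supervised/finetune.py`, when `--train-with-neg` is set, each data-parallel rank contributes a contiguous block of `local_context_size` gathered contexts, so the positive for query `i` on rank `r` sits at offset `r * local_context_size + i`. Below is a small sketch of that label layout using the example numbers from the code comment; `build_labels` is a hypothetical helper written only for this note.

```python
# Hypothetical helper reproducing the label layout described in the
# cross_entropy_loss_func comment (world_size=3, local batch=4, local contexts=8).
def build_labels(world_size, local_batch_size, local_context_size):
    labels = []
    for r in range(world_size):
        offset = r * local_context_size      # start of rank r's context block
        labels.extend(range(offset, offset + local_batch_size))
    return labels

assert build_labels(3, 4, 8) == [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
```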
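`check_and_append_tensor_for_gather` exists because `torch.distributed.all_gather` requires identically shaped tensors on every rank, while the number of hard-negative contexts can differ per rank; each tensor is therefore right-padded along dim 0 up to the gathered maximum before the collective. A sketch of just the padding arithmetic follows, with the distributed calls left out; `pad_first_dim` and `max_length` are stand-ins for this note, not names from the patch.

```python
# Sketch of the F.pad arithmetic used before all_gather; pad_first_dim and
# max_length are assumed stand-ins for illustration.
import torch
import torch.nn.functional as F

def pad_first_dim(input_, max_length):
    current_length = input_.size(0)
    if max_length > current_length:
        # F.pad's pad tuple runs from the last dimension backwards, so the
        # final entry pads the "after" side of dim 0 and nothing else.
        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
                  (max_length - current_length,)
        input_ = F.pad(input_, padding)
    return input_

x = torch.ones(3, 5, dtype=torch.long)
print(pad_first_dim(x, 7).shape)   # torch.Size([7, 5])
```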