74 commits
8b04e0e
ICT zeroshot evaluation code
mpatwary Mar 10, 2021
661553f
made more generic, aligned with other tasks
mpatwary Mar 11, 2021
43c9137
Fixed based on review recommendation
mpatwary Mar 19, 2021
4056539
fixed another issue
mpatwary Mar 19, 2021
a5acbf5
Merge branch 'main' into main_retriver_merge_ict_eval
mpatwary Mar 20, 2021
10ff060
implementing DPR
mpatwary Apr 9, 2021
cdde433
Merge branch 'main' into main_retriver_merge_dpr
mpatwary Apr 9, 2021
06076c7
DPR implementation
mpatwary Apr 23, 2021
957d1c9
Merge branch 'main' into main_retriver_merge_dpr
Apr 26, 2021
b9fcb7b
adding dpr code
Apr 29, 2021
8004731
removed comments
Apr 29, 2021
f415dc8
removed comments
Apr 29, 2021
a8d172b
removed comments
Apr 29, 2021
220637f
DPR evaluation debugging
May 11, 2021
d2d5086
DPR ongoing
May 11, 2021
6d03d7a
DPR finetune and evaluation
May 12, 2021
f926720
fixing model evaluation of retriever
May 12, 2021
5409341
added pre and post process
May 12, 2021
7e335e1
added pre and post process
May 12, 2021
f64977f
evaluation works!
May 13, 2021
dca47cf
debugging DPR
May 14, 2021
3f75537
fix copy-n-paste error
stas00 May 17, 2021
07ca952
Typo fix in readme
devrimcavusoglu May 18, 2021
2dae74b
t5 fixes
stas00 May 18, 2021
4a09bb3
Merge branch 'main' into main_retriver_merge_dpr
mpatwary May 18, 2021
7a0710e
before cleaning the comments
mpatwary May 18, 2021
ccae9db
vit pipeline fixes
kvareddy May 18, 2021
2eaf6c7
cleaning the code
mpatwary May 18, 2021
2529380
additional cleaning
mpatwary May 19, 2021
8e44d61
renaming the folders
mpatwary May 19, 2021
113c636
Add temporary assert to finetuning until it can be fixed.
jaredcasper May 19, 2021
7577931
Fixed issues with ICT pretraining
mpatwary May 19, 2021
dfb6a9b
updated the evaluation script for retriever
mpatwary May 19, 2021
f21a662
updated the evaluation script for retriever
mpatwary May 19, 2021
a41e478
updated the evaluation script for retriever
mpatwary May 19, 2021
825375c
updated the evaluation script for retriever
mpatwary May 19, 2021
217f54b
Merge branch 'finetune_assert' into 'main'
shoeybi May 19, 2021
d078e54
added exit interval for finetuning
mpatwary May 20, 2021
63121a9
updating the scripts
mpatwary May 20, 2021
fda81a2
updating no load rng
mpatwary May 25, 2021
01fc083
Merge branch 'vit_pipeline_fixes' into 'main'
jaredcasper Jun 1, 2021
83c4d95
Merge branch 'main_retriver_merge_dpr' into 'main'
jaredcasper Jun 1, 2021
c7c65bb
updating script
mpatwary Jun 3, 2021
84eb016
Merge branch 'main' into main_retriver_merge_dpr
mpatwary Jun 3, 2021
3dadd16
Update T5 scripts
deepakn94 Jun 7, 2021
04c79f3
resolved hang issue
mpatwary Jun 8, 2021
ebfbfce
fixed the tensor size mismatch issue
mpatwary Jun 9, 2021
e46f326
fixed the evaluation hangs
mpatwary Jun 10, 2021
a983cab
Adding readme
mpatwary Jun 10, 2021
d562d7b
Adding readme
mpatwary Jun 10, 2021
1095d7e
Adding readme
mpatwary Jun 10, 2021
bab5cc4
Adding readme
mpatwary Jun 10, 2021
8661ca2
Adding readme
mpatwary Jun 10, 2021
293554a
Adding readme
mpatwary Jun 10, 2021
e287bf0
Adding readme
mpatwary Jun 10, 2021
c45109e
Adding readme
mpatwary Jun 10, 2021
473127f
Clean up README.md a bit
jaredcasper Jun 10, 2021
2845047
addressed comments
mpatwary Jun 10, 2021
98113c6
Merge branch 'main_retriver_merge_dpr' of ssh://gitlab-master.nvidia.…
mpatwary Jun 10, 2021
598d7ee
Merge branch 'main_retriver_merge_dpr' into 'main'
jaredcasper Jun 10, 2021
2be1e51
Merge branch 't5_scripts' into 'main'
jaredcasper Jun 10, 2021
9d350c9
updated readme
mpatwary Jun 10, 2021
baf2e2a
updated readme
mpatwary Jun 10, 2021
32da2e7
updated readme
mpatwary Jun 10, 2021
4c92ca8
updated readme
mpatwary Jun 10, 2021
82b69e8
Merge branch 'main_retriver_merge_dpr' into 'main'
jaredcasper Jun 11, 2021
7898c9a
Merge branch 't5' of https://github.com/stas00/Megatron-LM into githu…
jaredcasper Jun 11, 2021
e1318f0
Merge branch 'typo-fix' of https://github.com/devrimcavusoglu/Megatro…
jaredcasper Jun 11, 2021
4a35d50
Merge branch 'patch-1' of https://github.com/stas00/Megatron-LM into …
jaredcasper Jun 11, 2021
90e0a0d
Merge branch 'github-pr' into 'main'
jaredcasper Jun 11, 2021
3ed5da6
rough carbon tracker test
stas00 Jul 6, 2021
318ef29
fix bug when restarting with no eval in round 1
stas00 Jul 23, 2021
6039727
Merge remote-tracking branch 'origin/master' into cc
stas00 Jul 24, 2021
00bc3ea
wip
stas00 Jul 24, 2021
8 changes: 7 additions & 1 deletion README.md
@@ -113,6 +113,11 @@ python tools/preprocess_data.py \

The output will be two files named, in this case, `my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`. The `--data-path` specified in later BERT training is the full path and new filename, but without the file extension.
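For example, if the two preprocessed files sit under a directory such as `/workspace/data` (a placeholder path), later BERT training would reference them as:
<pre>
--data-path /workspace/data/my-bert_text_sentence \
</pre>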

For T5, use the same preprocessing as for BERT, perhaps renaming the output prefix to:
<pre>
--output-prefix my-t5 \
</pre>
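A full T5 preprocessing command is then identical to the BERT invocation above apart from the prefix; a minimal sketch, assuming the same corpus and vocab files as in the BERT example:
<pre>
python tools/preprocess_data.py \
       --input my-corpus.json \
       --output-prefix my-t5 \
       --vocab bert-vocab.txt \
       --dataset-impl mmap \
       --tokenizer-type BertWordPieceLowerCase \
       --split-sentences
</pre>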

Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type:
<pre>
python tools/preprocess_data.py \
@@ -247,13 +252,14 @@ T5_ARGS="--num-layers 24 \
--micro-batch-size 16 \
--global-batch-size 2048 \
--vocab-file $VOCAB_FILE \
--vocab-extra-ids 100 \
--split 949,50,1 \
--fp16"

OUTPUT_ARGS=&#60;same as those in <a href="#bert-pretraining">BERT pretraining</a> above&#62;

python pretrain_t5.py \
$BERT_ARGS \
$T5_ARGS \
$OUTPUT_ARGS \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
32 changes: 0 additions & 32 deletions examples/create_embeddings.sh

This file was deleted.

@@ -1,19 +1,19 @@
#!/bin/bash

# Evaluate natural question test data given Wikipedia embeddings and pretrained
# ICT model
# ICT model or a finetuned model for Natural Question task

# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py

EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
CHECKPOINT_PATH=<Specify path of pretrained ICT model or finetuned model>

QA_FILE=<Path of the natural question test dataset>
QA_FILE=<Path of the natural question dev or test dataset>

python tasks/main.py \
--task ICT-ZEROSHOT-NQ \
--task RETRIEVER-EVAL \
--tokenizer-type BertWordPieceLowerCase \
--num-layers 12 \
--hidden-size 768 \
@@ -29,8 +29,10 @@ python tasks/main.py \
--retriever-seq-length 256 \
--vocab-file bert-vocab.txt\
--qa-data-test ${QA_FILE} \
--num-workers 2 \
--faiss-use-gpu \
--retriever-report-topk-accuracies 1 5 20 100 \
--fp16
--fp16 \
--indexer-log-interval 1000 \
--indexer-batch-size 128


56 changes: 56 additions & 0 deletions examples/finetune_retriever_distributed.sh
@@ -0,0 +1,56 @@
#!/bin/bash

# Finetune a BERT or pretrained ICT model using the Google Natural Questions data
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
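# As a sketch, the NQ retriever data might be fetched with that script roughly as
# follows (the resource key is an assumption; check the DPR repository for the
# exact names):
#   python data/download_data.py --resource data.retriever.nq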

WORLD_SIZE=8

DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"

CHECKPOINT_PATH=<Specify path for the finetuned retriever model>

# Load either of the below
BERT_LOAD_PATH=<Path of BERT pretrained model>
PRETRAINED_CHECKPOINT=<Path of Pretrained ICT model>

python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--task RET-FINETUNE-NQ \
--train-with-neg \
--train-hard-neg 1 \
--pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--tokenizer-type BertWordPieceLowerCase \
--train-data nq-train.json \
--valid-data nq-dev.json \
--save ${CHECKPOINT_PATH} \
--load ${CHECKPOINT_PATH} \
--vocab-file bert-vocab.txt \
--bert-load ${BERT_LOAD_PATH} \
--save-interval 5000 \
--log-interval 10 \
--eval-interval 20000 \
--eval-iters 100 \
--indexer-log-interval 1000 \
--faiss-use-gpu \
--DDP-impl torch \
--fp16 \
--retriever-report-topk-accuracies 1 5 10 20 100 \
--seq-length 512 \
--retriever-seq-length 256 \
--max-position-embeddings 512 \
--retriever-score-scaling \
--epochs 80 \
--micro-batch-size 8 \
--eval-micro-batch-size 16 \
--indexer-batch-size 128 \
--lr 2e-5 \
--lr-warmup-fraction 0.01 \
--weight-decay 1e-1
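Once the placeholder paths above are filled in, the script can be launched directly, for example:
<pre>
bash examples/finetune_retriever_distributed.sh
</pre>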
5 changes: 3 additions & 2 deletions examples/pretrain_t5.sh
@@ -15,7 +15,7 @@ python pretrain_t5.py \
--encoder-seq-length 512 \
--decoder-seq-length 128 \
--micro-batch-size 16 \
--global-batch-size 2048 \
--global-batch-size 16 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--lr-decay-iters 1000000 \
@@ -35,4 +35,5 @@ python pretrain_t5.py \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
--fp16 \
--vocab-extra-ids 100
5 changes: 3 additions & 2 deletions examples/pretrain_t5_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--encoder-seq-length 512 \
--decoder-seq-length 128 \
--micro-batch-size 16 \
--global-batch-size 2048 \
--global-batch-size 128 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--lr-decay-iters 1000000 \
@@ -44,4 +44,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
--fp16 \
--vocab-extra-ids 100
6 changes: 3 additions & 3 deletions examples/pretrain_t5_distributed_with_mp.sh
@@ -24,8 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--encoder-seq-length 512 \
--decoder-seq-length 128 \
--micro-batch-size 16 \
--global-batch-size 2048 \
--seq-length 512 \
--global-batch-size 128 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--lr-decay-iters 1000000 \
@@ -45,4 +44,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--fp16
--fp16 \
--vocab-extra-ids 100
2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -440,6 +440,8 @@ def _add_training_args(parser):
help='Run optimizer on CPU')
group.add_argument('--cpu_torch_adam', action='store_true',
help='Use Torch Adam as optimizer on CPU.')
group.add_argument('--codecarbon-dir', type=str, default=None,
help='Write CodeCarbon logs to this directory.')

return parser

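When it lands in a launch command, the new flag would be appended like the other training arguments; a minimal sketch, assuming the target directory already exists and the CodeCarbon tracker code consumes the value:
<pre>
       --codecarbon-dir /path/to/codecarbon-logs \
</pre>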
2 changes: 1 addition & 1 deletion megatron/checkpointing.py
@@ -60,8 +60,8 @@ def _compare(arg_name, old_arg_name=None):
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
if args.vocab_file:
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
12 changes: 7 additions & 5 deletions megatron/indexer.py
@@ -1,15 +1,16 @@
import sys
import time
import torch
import torch.distributed as dist

from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model


@@ -29,7 +30,6 @@ def __init__(self):
# need to know whether we're using a REALM checkpoint (args.load)
# or ICT checkpoint
assert not (args.load and args.ict_load)
#self.using_realm_chkpt = args.ict_load is None

self.log_interval = args.indexer_log_interval
self.batch_size = args.indexer_batch_size
@@ -47,8 +47,8 @@ def load_attributes(self):
if self.biencoder_shared_query_context_model:
only_context_model = False

model = get_model(lambda: biencoder_model_provider(only_context_model \
= only_context_model, biencoder_shared_query_context_model = \
model = get_model(get_model_provider(only_context_model=\
only_context_model, biencoder_shared_query_context_model=\
self.biencoder_shared_query_context_model))

self.model = load_biencoder_checkpoint(model,
@@ -85,6 +85,7 @@ def build_and_save_index(self):
"""
assert len(self.model) == 1
unwrapped_model = self.model[0]

while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module

@@ -103,6 +104,7 @@ def build_and_save_index(self):
context_logits = unwrapped_model.embed_text(
unwrapped_model.context_model, context_tokens, context_mask,
context_types)

context_logits = detach(context_logits)
row_id = detach(row_id)

2 changes: 1 addition & 1 deletion megatron/learning_rates.py
@@ -87,7 +87,7 @@ def get_lr(self):
else:
raise Exception('{} decay style is not supported.'.format(
self.decay_style))

return self.min_lr + coeff * delta_lr

