Merged

30 commits
ccc3661
replace init_ddp_connection for index init
lhoestq Nov 12, 2020
a16b78b
style
lhoestq Nov 12, 2020
a528279
add finetune test
lhoestq Nov 12, 2020
10de28b
add test data
lhoestq Nov 13, 2020
fdd59fb
move generate tensors to device
lhoestq Nov 13, 2020
c1cd3c1
add test on EM metric
lhoestq Nov 13, 2020
ae1366c
style
lhoestq Nov 13, 2020
adf56d4
allow multi process test
lhoestq Nov 13, 2020
ddf2eca
keep gloo process group for retrieval
lhoestq Nov 17, 2020
73194c7
add multi-gpu test
lhoestq Nov 17, 2020
aebdcb8
use custom accelerator
lhoestq Nov 17, 2020
f307b9f
clean test finetune
lhoestq Nov 17, 2020
30cd148
minor
lhoestq Nov 17, 2020
46f8660
style
lhoestq Nov 17, 2020
572b070
style
lhoestq Nov 17, 2020
3e5cc12
typo
lhoestq Nov 17, 2020
dcf42e3
use python call instead of imported main function
lhoestq Nov 20, 2020
0219b3a
return_dict fix in modeling_rag
lhoestq Nov 20, 2020
a2ace92
use float32 in retrieval
lhoestq Nov 20, 2020
90bee66
store as float32 as well in the custom knowledge dataset example
lhoestq Nov 20, 2020
d6147dd
style
lhoestq Nov 20, 2020
16d9dc9
rename to finetune_rag
lhoestq Nov 20, 2020
a080425
style
lhoestq Nov 20, 2020
611568c
update readme
lhoestq Nov 20, 2020
9652768
rename utils and callbacks to utils_rag and callbacks_rag
lhoestq Nov 20, 2020
ae0c3a3
fix test
lhoestq Nov 20, 2020
22ce5d5
patrick's comments
lhoestq Nov 20, 2020
45e733d
generate dummy data in the finetune test script
lhoestq Nov 20, 2020
b40c6f1
remove dummy data files
lhoestq Nov 20, 2020
0a71fa0
style
lhoestq Nov 20, 2020
2 changes: 2 additions & 0 deletions examples/lightning_base.py
@@ -384,6 +384,8 @@ def generic_train(
train_params["distributed_backend"] = "ddp"

train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
train_params["profiler"] = extra_train_kwargs.get("profiler", None)

trainer = pl.Trainer.from_argparse_args(
args,
34 changes: 30 additions & 4 deletions examples/rag/README.md
@@ -7,8 +7,8 @@ to the retriever to extract relevant context documents. The documents are then p
Such contextualized inputs are passed to the generator.

Read more about RAG at https://arxiv.org/abs/2005.11401.
# Finetuning

# Finetuning

Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
```bash
@@ -20,10 +20,10 @@ test.source
test.target
```
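
For illustration, here is a minimal sketch of how such a directory could be created, assuming the same line-aligned convention as `examples/seq2seq` (the n-th line of each `*.source` file is paired with the n-th line of the corresponding `*.target` file); the directory name and QA pairs below are made up:

```python
# Hypothetical example only: write two toy QA pairs per split in the
# line-aligned format expected by the finetuning script.
from pathlib import Path

data_dir = Path("dummy_qa_data")
data_dir.mkdir(exist_ok=True)

pairs = [
    ("who wrote the origin of species?", "charles darwin"),
    ("what is the capital of france?", "paris"),
]

for split in ("train", "val", "test"):
    (data_dir / f"{split}.source").write_text("\n".join(q for q, _ in pairs) + "\n")
    (data_dir / f"{split}.target").write_text("\n".join(a for _, a in pairs) + "\n")
```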

A sample finetuning command (run ` ./examples/rag/finetune.py --help` to list all available options):
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):

```bash
python examples/rag/finetune.py \
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
@@ -45,7 +45,7 @@ python examples/rag/consolidate_rag_checkpoint.py \
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
--dest path/to/checkpoint
```
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune.py` script.
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
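
Conceptually, the consolidation step combines a pretrained question encoder and a pretrained generator into a single RAG checkpoint. A rough sketch of the idea (not the actual `consolidate_rag_checkpoint.py` script, which also takes care of the config and tokenizers):

```python
# Sketch only: build a RAG checkpoint from a DPR question encoder and a BART
# generator, then save it so it can be passed as --model_name_or_path.
from transformers import RagSequenceForGeneration

model = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base",  # question encoder
    "facebook/bart-large",                           # generator
)
model.save_pretrained("path/to/checkpoint")
```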


# Evaluation
@@ -130,3 +130,29 @@ python examples/rag/eval_rag.py \
--print_predictions \
--recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
```

# Use your own knowledge source

By default, RAG uses the English Wikipedia as a knowledge source, known as the `wiki_dpr` dataset.
With `use_own_knowledge_dataset.py` you can build your own knowledge source, *e.g.* for RAG.

For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
```bash
python examples/rag/use_own_knowledge_dataset.py \
--csv_path path/to/my_csv \
--output_dir path/to/my_knowledge_dataset
```
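
To sanity-check the result, the saved passages can be loaded back with the `datasets` library. A sketch, assuming the script writes the processed passages with `Dataset.save_to_disk` and uses `wiki_dpr`-style columns (`title`, `text`, `embeddings`); check the actual contents of your `output_dir`:

```python
# Sketch: inspect the knowledge dataset produced by use_own_knowledge_dataset.py.
# Paths and column names are assumptions; adjust them to what the script wrote.
from datasets import load_from_disk

dataset = load_from_disk("path/to/my_knowledge_dataset")
print(dataset)                                        # number of passages and columns
print(dataset[0]["title"], dataset[0]["text"][:100])  # peek at the first passage
```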

The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
```bash
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
--model_type rag_sequence \
--fp16 \
--gpus 8 \
--index_name custom \
--passages_path path/to/data/my_knowledge_dataset \
--index_path path/to/my_knowledge_dataset_hnsw_index.faiss
```
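
The same custom index can also be plugged into a RAG model at inference time through `RagRetriever`. A sketch using the pretrained `facebook/rag-sequence-nq` checkpoint (substitute your own consolidated or finetuned checkpoint directory as appropriate; the paths below are illustrative):

```python
# Sketch: query a RAG model against the custom knowledge source instead of the
# default wiki_dpr index.
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

model_name = "facebook/rag-sequence-nq"
retriever = RagRetriever.from_pretrained(
    model_name,
    index_name="custom",
    passages_path="path/to/data/my_knowledge_dataset",
    index_path="path/to/my_knowledge_dataset_hnsw_index.faiss",
)
tokenizer = RagTokenizer.from_pretrained(model_name)
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

inputs = tokenizer.question_encoder("who discovered penicillin?", return_tensors="pt")
generated = model.generate(input_ids=inputs["input_ids"])
print(tokenizer.generator.batch_decode(generated, skip_special_tokens=True))
```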
4 changes: 2 additions & 2 deletions examples/rag/callbacks.py → examples/rag/callbacks_rag.py
@@ -8,7 +8,7 @@
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only

from utils import save_json
from utils_rag import save_json


def count_trainable_parameters(model):
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
monitor=f"val_{metric}",
mode="max",
save_top_k=3,
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
Contributor

why go from 0 to 1?

Member Author

Oh, I changed that to speed up the test and forgot to remove it. I can just modify the validation frequency.

)
return checkpoint_callback

1 change: 0 additions & 1 deletion examples/rag/distributed_retriever.py
@@ -40,7 +40,6 @@ def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, inde
generator_tokenizer=generator_tokenizer,
index=index,
)

self.process_group = None

def init_retrieval(self, distributed_port: int):
Expand Down
104 changes: 59 additions & 45 deletions examples/rag/finetune.py → examples/rag/finetune_rag.py
@@ -1,12 +1,10 @@
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""

import argparse
import glob
import logging
import os
import sys
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
@@ -15,29 +13,31 @@
import pytorch_lightning as pl
import torch
import torch.distributed as dist
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
from torch.utils.data import DataLoader

from transformers import (
AutoConfig,
AutoTokenizer,
BartForConditionalGeneration,
BatchEncoding,
RagConfig,
RagSequenceForGeneration,
RagTokenForGeneration,
RagTokenizer,
T5ForConditionalGeneration,
get_linear_schedule_with_warmup,
)
from transformers import logging as transformers_logging


from callbacks import ( # noqa: E402 # isort:skipq
from callbacks_rag import ( # noqa: E402 # isort:skipq
get_checkpoint_callback,
get_early_stopping_callback,
Seq2SeqLoggingCallback,
)
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
from utils import ( # noqa: E402 # isort:skip
from utils_rag import ( # noqa: E402 # isort:skip
calculate_exact_match,
flatten_list,
get_git_info,
@@ -67,6 +67,30 @@ def __init__(self, *args, **kwargs):
self.__dict__ = self


# In PTL >v1.0, the `init_ddp_connection` method of the `LightningModule`
# is no longer used; it has been moved into the DDPAccelerator instead.
# We override DDPAccelerator to add our custom logic for initializing the
# retriever.
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py


class CustomAccel(DDPAccelerator):
def __init__(self, trainer=None, **kwargs):
# Trainer is set later.
super().__init__(trainer, **kwargs)

def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
logger.info("Custom init_ddp_connection.")
module = self.trainer.model
if self.cluster_environment is None:
self.cluster_environment = TorchElasticEnvironment()
self.distributed_port = module.hparams.distributed_port
os.environ["MASTER_PORT"] = str(self.distributed_port)
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
if module.is_rag_model:
module.model.rag.retriever.init_retrieval(self.distributed_port)


class GenerativeQAModule(BaseTransformer):
mode = "generative_qa"
loss_names = ["loss"]
@@ -91,23 +115,24 @@ def __init__(self, hparams, **kwargs):
config = config_class.from_pretrained(hparams.model_name_or_path)

# set retriever parameters
config.index_name = args.index_name or config.index_name
config.passages_path = args.passages_path or config.passages_path
config.index_path = args.index_path or config.index_path
config.index_name = hparams.index_name or config.index_name
config.passages_path = hparams.passages_path or config.passages_path
config.index_path = hparams.index_path or config.index_path
config.use_dummy_dataset = hparams.use_dummy_dataset

# set extra_model_params for generator configs and load_model
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
if self.is_rag_model:
if args.prefix is not None:
config.generator.prefix = args.prefix
if hparams.prefix is not None:
config.generator.prefix = hparams.prefix
config.label_smoothing = hparams.label_smoothing
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
prefix = config.question_encoder.prefix
else:
if args.prefix is not None:
config.prefix = args.prefix
if hparams.prefix is not None:
config.prefix = hparams.prefix
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
prefix = config.prefix
@@ -152,11 +177,9 @@ def __init__(self, hparams, **kwargs):
self.num_workers = hparams.num_workers
self.distributed_port = self.hparams.distributed_port

def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
logger.info("Custom init_ddp_connection.")
os.environ["MASTER_PORT"] = str(self.distributed_port)
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
if self.is_rag_model:
# For single GPU training, init_ddp_connection is not called.
# So we need to initialize the retrievers here.
if hparams.gpus <= 1:
self.model.retriever.init_retrieval(self.distributed_port)

def forward(self, input_ids, **kwargs):
@@ -270,6 +293,7 @@ def calc_generative_metrics(self, preds, target) -> Dict:

def _generative_step(self, batch: dict) -> dict:
start_time = time.time()
batch = BatchEncoding(batch).to(device=self.model.device)
generated_ids = self.model.generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
@@ -322,17 +346,6 @@ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False)

def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
)
if max(scheduler.get_last_lr()) > 0:
warnings.warn("All learning rates are 0")
self.lr_scheduler = scheduler
return dataloader

def val_dataloader(self) -> DataLoader:
@@ -429,10 +442,24 @@ def add_retriever_specific_args(parser):
default=None,
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
return parser


def main(args, model=None) -> GenerativeQAModule:
def main(args=None, model=None) -> GenerativeQAModule:

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)

args = args or parser.parse_args()

Path(args.output_dir).mkdir(exist_ok=True)
if model is None:
model: GenerativeQAModule = GenerativeQAModule(args)
@@ -461,38 +488,25 @@ def main(args, model=None) -> GenerativeQAModule:
if args.early_stopping_patience >= 0
else False
)

trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=logger,
accelerator=CustomAccel() if args.gpus > 1 else None,
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")

if not args.do_predict:
return model

model.hparams.test_checkpoint = ""
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
if checkpoints:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1] # best checkpoint
trainer.logger.log_hyperparams(model.hparams)

# test() without a model tests using the best checkpoint automatically
trainer.test()

return model


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)

args = parser.parse_args()

main(args)
main()
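
Note on the `main(args=None, ...)` change above: the argument parser now lives inside `main`, so the script can be launched as a plain `python` call (which is what the new finetuning test relies on) instead of importing `main` and building the parser by hand. A hypothetical sketch of such an invocation; the arguments and paths are illustrative only:

```python
# Illustrative only: launch finetune_rag.py in a subprocess, the way a test
# might drive it end to end. Real runs need a valid data_dir and checkpoint.
import subprocess
import sys

cmd = [
    sys.executable,
    "examples/rag/finetune_rag.py",
    "--data_dir", "dummy_qa_data",
    "--output_dir", "rag_finetune_output",
    "--model_name_or_path", "facebook/rag-sequence-base",
    "--model_type", "rag_sequence",
    "--gpus", "1",
    "--do_train",
]
subprocess.run(cmd, check=True)
```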
2 changes: 1 addition & 1 deletion examples/rag/finetune.sh → examples/rag/finetune_rag.sh
@@ -4,7 +4,7 @@ export PYTHONPATH="../":"${PYTHONPATH}"
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune.sh --help to see all the possible options

python examples/rag/finetune.py \
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \