diff --git a/examples/lightning_base.py b/examples/lightning_base.py
index 0c4913e15fe1..1c50765178b0 100644
--- a/examples/lightning_base.py
+++ b/examples/lightning_base.py
@@ -384,6 +384,8 @@ def generic_train(
         train_params["distributed_backend"] = "ddp"

     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
+    train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
+    train_params["profiler"] = extra_train_kwargs.get("profiler", None)

     trainer = pl.Trainer.from_argparse_args(
         args,
diff --git a/examples/rag/README.md b/examples/rag/README.md
index 65b126666ecf..38e2071f2445 100644
--- a/examples/rag/README.md
+++ b/examples/rag/README.md
@@ -7,8 +7,8 @@ to the retriever to extract relevant context documents. The documents are then p
 Such contextualized inputs are passed to the generator.

 Read more about RAG at https://arxiv.org/abs/2005.11401.

-# Finetuning 
+# Finetuning

 Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
 ```bash
@@ -20,10 +20,10 @@ test.source
 test.target
 ```

-A sample finetuning command (run ` ./examples/rag/finetune.py --help` to list all available options):
+A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):

 ```bash
-python examples/rag/finetune.py \
+python examples/rag/finetune_rag.py \
     --data_dir $DATA_DIR \
     --output_dir $OUTPUT_DIR \
     --model_name_or_path $MODEL_NAME_OR_PATH \
@@ -45,7 +45,7 @@ python examples/rag/consolidate_rag_checkpoint.py \
     --question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
     --dest path/to/checkpoint
 ```
-You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune.py` script.
+You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.

 # Evaluation

@@ -130,3 +130,29 @@ python examples/rag/eval_rag.py \
     --print_predictions \
     --recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
 ```
+
+# Use your own knowledge source
+
+By default, RAG uses the English Wikipedia as a knowledge source, known as the 'wiki_dpr' dataset.
+With `use_own_knowledge_dataset.py` you can build your own knowledge source for RAG.
+
+For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
+```bash
+python examples/rag/use_own_knowledge_dataset.py \
+    --csv_path path/to/my_csv \
+    --output_dir path/to/my_knowledge_dataset
+```
+
+The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
+```bash
+python examples/rag/finetune_rag.py \
+    --data_dir $DATA_DIR \
+    --output_dir $OUTPUT_DIR \
+    --model_name_or_path $MODEL_NAME_OR_PATH \
+    --model_type rag_sequence \
+    --fp16 \
+    --gpus 8 \
+    --index_name custom \
+    --passages_path path/to/my_knowledge_dataset \
+    --index_path path/to/my_knowledge_dataset_hnsw_index.faiss
+```
\ No newline at end of file
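Not part of the patch, just an illustration: the three retriever options documented above (`--index_name custom`, `--passages_path`, `--index_path`) map onto the corresponding `RagConfig` fields, so the same custom index can also be loaded at inference time. A minimal sketch, assuming placeholder paths and the `RagRetriever`/`RagSequenceForGeneration` API described in the transformers documentation:

```python
# Illustration only (not included in this patch): load a RAG model together with
# a custom index built by examples/rag/use_own_knowledge_dataset.py.
# The model name and all paths below are placeholders.
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

model_name = "facebook/rag-sequence-nq"  # or a checkpoint created with consolidate_rag_checkpoint.py

retriever = RagRetriever.from_pretrained(
    model_name,
    index_name="custom",
    passages_path="path/to/my_knowledge_dataset",
    index_path="path/to/my_knowledge_dataset_hnsw_index.faiss",
)
tokenizer = RagTokenizer.from_pretrained(model_name)
model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)

# The question encoder tokenizer prepares the question, the generator tokenizer decodes the answer.
inputs = tokenizer.question_encoder("who holds the record in 100m freestyle", return_tensors="pt")
generated = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(tokenizer.generator.batch_decode(generated, skip_special_tokens=True))
```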
diff --git a/examples/rag/callbacks.py b/examples/rag/callbacks_rag.py
similarity index 97%
rename from examples/rag/callbacks.py
rename to examples/rag/callbacks_rag.py
index 099cf2bbdfac..ce30db88cdd6 100644
--- a/examples/rag/callbacks.py
+++ b/examples/rag/callbacks_rag.py
@@ -8,7 +8,7 @@
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.utilities import rank_zero_only

-from utils import save_json
+from utils_rag import save_json


 def count_trainable_parameters(model):
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
         monitor=f"val_{metric}",
         mode="max",
         save_top_k=3,
-        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
+        period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
diff --git a/examples/rag/distributed_retriever.py b/examples/rag/distributed_retriever.py
index a931f183aa26..cedd2c33409f 100644
--- a/examples/rag/distributed_retriever.py
+++ b/examples/rag/distributed_retriever.py
@@ -40,7 +40,6 @@ def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, inde
             generator_tokenizer=generator_tokenizer,
             index=index,
         )
-        self.process_group = None

     def init_retrieval(self, distributed_port: int):
diff --git a/examples/rag/finetune.py b/examples/rag/finetune_rag.py
similarity index 89%
rename from examples/rag/finetune.py
rename to examples/rag/finetune_rag.py
index 9882b9e2dc12..b62da19688ce 100644
--- a/examples/rag/finetune.py
+++ b/examples/rag/finetune_rag.py
@@ -1,12 +1,10 @@
 """Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""

 import argparse
-import glob
 import logging
 import os
 import sys
 import time
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple
@@ -15,29 +13,31 @@
 import pytorch_lightning as pl
 import torch
 import torch.distributed as dist
+from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
+from pytorch_lightning.cluster_environments import TorchElasticEnvironment
 from torch.utils.data import DataLoader

 from transformers import (
     AutoConfig,
     AutoTokenizer,
     BartForConditionalGeneration,
+    BatchEncoding,
     RagConfig,
     RagSequenceForGeneration,
     RagTokenForGeneration,
     RagTokenizer,
     T5ForConditionalGeneration,
-    get_linear_schedule_with_warmup,
 )
 from transformers import logging as transformers_logging

-from callbacks import (  # noqa: E402 # isort:skipq
+from callbacks_rag import (  # noqa: E402 # isort:skipq
     get_checkpoint_callback,
     get_early_stopping_callback,
     Seq2SeqLoggingCallback,
 )
 from distributed_retriever import RagPyTorchDistributedRetriever  # noqa: E402 # isort:skip
-from utils import (  # noqa: E402 # isort:skip
+from utils_rag import (  # noqa: E402 # isort:skip
     calculate_exact_match,
     flatten_list,
     get_git_info,
@@ -67,6 +67,30 @@ def __init__(self, *args, **kwargs):
         self.__dict__ = self


+# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
+# is no longer used, and is moved into DDPAccelerator instead.
+# We override DDPAccelerator to add our custom logic for initializing the
+# retriever.
+# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
+
+
+class CustomAccel(DDPAccelerator):
+    def __init__(self, trainer=None, **kwargs):
+        # Trainer is set later.
+        super().__init__(trainer, **kwargs)
+
+    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
+        logger.info("Custom init_ddp_connection.")
+        module = self.trainer.model
+        if self.cluster_environment is None:
+            self.cluster_environment = TorchElasticEnvironment()
+        self.distributed_port = module.hparams.distributed_port
+        os.environ["MASTER_PORT"] = str(self.distributed_port)
+        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
+        if module.is_rag_model:
+            module.model.rag.retriever.init_retrieval(self.distributed_port)
+
+
 class GenerativeQAModule(BaseTransformer):
     mode = "generative_qa"
     loss_names = ["loss"]
@@ -91,23 +115,24 @@ def __init__(self, hparams, **kwargs):
         config = config_class.from_pretrained(hparams.model_name_or_path)

         # set retriever parameters
-        config.index_name = args.index_name or config.index_name
-        config.passages_path = args.passages_path or config.passages_path
-        config.index_path = args.index_path or config.index_path
+        config.index_name = hparams.index_name or config.index_name
+        config.passages_path = hparams.passages_path or config.passages_path
+        config.index_path = hparams.index_path or config.index_path
+        config.use_dummy_dataset = hparams.use_dummy_dataset

         # set extra_model_params for generator configs and load_model
         extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
         if self.is_rag_model:
-            if args.prefix is not None:
-                config.generator.prefix = args.prefix
+            if hparams.prefix is not None:
+                config.generator.prefix = hparams.prefix
             config.label_smoothing = hparams.label_smoothing
             hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
             retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
             model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
             prefix = config.question_encoder.prefix
         else:
-            if args.prefix is not None:
-                config.prefix = args.prefix
+            if hparams.prefix is not None:
+                config.prefix = hparams.prefix
             hparams, config = set_extra_model_params(extra_model_params, hparams, config)
             model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
             prefix = config.prefix
@@ -152,11 +177,9 @@ def __init__(self, hparams, **kwargs):
         self.num_workers = hparams.num_workers
         self.distributed_port = self.hparams.distributed_port

-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
-        logger.info("Custom init_ddp_connection.")
-        os.environ["MASTER_PORT"] = str(self.distributed_port)
-        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
-        if self.is_rag_model:
+        # For single GPU training, init_ddp_connection is not called.
+        # So we need to initialize the retriever here.
+        if hparams.gpus <= 1:
             self.model.retriever.init_retrieval(self.distributed_port)

     def forward(self, input_ids, **kwargs):
@@ -270,6 +293,7 @@ def calc_generative_metrics(self, preds, target) -> Dict:

     def _generative_step(self, batch: dict) -> dict:
         start_time = time.time()
+        batch = BatchEncoding(batch).to(device=self.model.device)
         generated_ids = self.model.generate(
             batch["input_ids"],
             attention_mask=batch["attention_mask"],
@@ -322,17 +346,6 @@ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False)

     def train_dataloader(self) -> DataLoader:
         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
-        t_total = (
-            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
-        scheduler = get_linear_schedule_with_warmup(
-            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
-        )
-        if max(scheduler.get_last_lr()) > 0:
-            warnings.warn("All learning rates are 0")
-        self.lr_scheduler = scheduler
         return dataloader

     def val_dataloader(self) -> DataLoader:
@@ -429,10 +442,24 @@ def add_retriever_specific_args(parser):
             default=None,
             help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
         )
+        parser.add_argument(
+            "--use_dummy_dataset",
+            type=bool,
+            default=False,
+            help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+        )
         return parser


-def main(args, model=None) -> GenerativeQAModule:
+def main(args=None, model=None) -> GenerativeQAModule:
+
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
+    parser = GenerativeQAModule.add_retriever_specific_args(parser)
+
+    args = args or parser.parse_args()
+
     Path(args.output_dir).mkdir(exist_ok=True)
     if model is None:
         model: GenerativeQAModule = GenerativeQAModule(args)
@@ -461,6 +488,7 @@ def main(args, model=None) -> GenerativeQAModule:
         if args.early_stopping_patience >= 0
         else False
     )
+
     trainer: pl.Trainer = generic_train(
         model,
         args,
@@ -468,31 +496,17 @@ def main(args, model=None) -> GenerativeQAModule:
         checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
         early_stopping_callback=es_callback,
         logger=logger,
+        accelerator=CustomAccel() if args.gpus > 1 else None,
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")

     if not args.do_predict:
         return model

-    model.hparams.test_checkpoint = ""
-    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
-    if checkpoints:
-        model.hparams.test_checkpoint = checkpoints[-1]
-        trainer.resume_from_checkpoint = checkpoints[-1]  # best checkpoint
-    trainer.logger.log_hyperparams(model.hparams)
-
     # test() without a model tests using the best checkpoint automatically
     trainer.test()
-
     return model


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser = pl.Trainer.add_argparse_args(parser)
-    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
-    parser = GenerativeQAModule.add_retriever_specific_args(parser)
-
-    args = parser.parse_args()
-
-    main(args)
+    main()
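An aside, not part of the patch: with the new `main(args=None, model=None)` signature above, `finetune_rag.py` builds its own parser when no namespace is passed, while still accepting a pre-built one. A rough sketch of the programmatic route, with placeholder paths and only a few options spelled out:

```python
# Illustration only: drive finetune_rag.main() from Python instead of the CLI.
# Paths and hyperparameters below are placeholders.
import argparse
import os

import pytorch_lightning as pl

import finetune_rag

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = finetune_rag.GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = finetune_rag.GenerativeQAModule.add_retriever_specific_args(parser)
args = parser.parse_args(
    [
        "--data_dir", "path/to/data",
        "--output_dir", "path/to/output",
        "--model_name_or_path", "facebook/rag-sequence-base",
        "--model_type", "rag_sequence",
        "--gpus", "1",
    ]
)

model = finetune_rag.main(args)  # equivalent to `python finetune_rag.py ...`
```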
diff --git a/examples/rag/finetune.sh b/examples/rag/finetune_rag.sh
similarity index 96%
rename from examples/rag/finetune.sh
rename to examples/rag/finetune_rag.sh
index ce82070aaa3d..577b6ebd0dbd 100755
--- a/examples/rag/finetune.sh
+++ b/examples/rag/finetune_rag.sh
@@ -4,7 +4,7 @@ export PYTHONPATH="../":"${PYTHONPATH}"

 # A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
 # run ./examples/rag/finetune.sh --help to see all the possible options

-python examples/rag/finetune.py \
+python examples/rag/finetune_rag.py \
     --data_dir $DATA_DIR \
     --output_dir $OUTPUT_DIR \
     --model_name_or_path $MODEL_NAME_OR_PATH \
diff --git a/examples/rag/test_finetune_rag.py b/examples/rag/test_finetune_rag.py
new file mode 100644
index 000000000000..164ecfd93211
--- /dev/null
+++ b/examples/rag/test_finetune_rag.py
@@ -0,0 +1,96 @@
+import json
+import logging
+import os
+import sys
+from pathlib import Path
+
+import finetune_rag
+from transformers.file_utils import is_apex_available
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+)
+
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger()
+
+
+class RagFinetuneExampleTests(TestCasePlus):
+    def _create_dummy_data(self, data_dir):
+        os.makedirs(data_dir, exist_ok=True)
+        contents = {"source": "What is love ?", "target": "life"}
+        n_lines = {"train": 12, "val": 2, "test": 2}
+        for split in ["train", "test", "val"]:
+            for field in ["source", "target"]:
+                content = "\n".join([contents[field]] * n_lines[split])
+                with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
+                    f.write(content)
+
+    def _run_finetune(self, gpus: int):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        output_dir = os.path.join(tmp_dir, "output")
+        data_dir = os.path.join(tmp_dir, "data")
+        self._create_dummy_data(data_dir=data_dir)
+
+        testargs = f"""
+            --data_dir {data_dir} \
+            --output_dir {output_dir} \
+            --model_name_or_path facebook/rag-sequence-base \
+            --model_type rag_sequence \
+            --do_train \
+            --do_predict \
+            --n_val -1 \
+            --val_check_interval 1.0 \
+            --train_batch_size 2 \
+            --eval_batch_size 1 \
+            --max_source_length 25 \
+            --max_target_length 25 \
+            --val_max_target_length 25 \
+            --test_max_target_length 25 \
+            --label_smoothing 0.1 \
+            --dropout 0.1 \
+            --attention_dropout 0.1 \
+            --weight_decay 0.001 \
+            --adam_epsilon 1e-08 \
+            --max_grad_norm 0.1 \
+            --lr_scheduler polynomial \
+            --learning_rate 3e-04 \
+            --num_train_epochs 1 \
+            --warmup_steps 4 \
+            --gradient_accumulation_steps 1 \
+            --distributed-port 8787 \
+            --use_dummy_dataset 1 \
+        """.split()
+
+        if gpus > 0:
+            testargs.append(f"--gpus={gpus}")
+            if is_apex_available():
+                testargs.append("--fp16")
+        else:
+            testargs.append("--gpus=0")
+            testargs.append("--distributed_backend=ddp_cpu")
+            testargs.append("--num_processes=2")
+
+        cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
+        execute_subprocess_async(cmd, env=self.get_env())
+
+        metrics_save_path = os.path.join(output_dir, "metrics.json")
+        with open(metrics_save_path) as f:
+            result = json.load(f)
+        return result
+
+    @require_torch_gpu
+    def test_finetune_gpu(self):
+        result = self._run_finetune(gpus=1)
+        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
+
+    @require_torch_multi_gpu
+    def test_finetune_multigpu(self):
+        result = self._run_finetune(gpus=2)
+        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
diff --git a/examples/rag/use_own_knowledge_dataset.py b/examples/rag/use_own_knowledge_dataset.py
index fd465e6900c7..269765caab86 100644
--- a/examples/rag/use_own_knowledge_dataset.py
+++ b/examples/rag/use_own_knowledge_dataset.py
@@ -7,7 +7,7 @@
 from typing import List, Optional

 import torch
-from datasets import load_dataset
+from datasets import Features, Sequence, Value, load_dataset

 import faiss
 from transformers import (
@@ -82,10 +82,14 @@ def main(
     # And compute the embeddings
     ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
     ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
+    new_features = Features(
+        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
+    )  # optional, save as float32 instead of float64 to save space
     dataset = dataset.map(
         partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
         batched=True,
         batch_size=processing_args.batch_size,
+        features=new_features,
     )

     # And finally save your dataset
diff --git a/examples/rag/utils.py b/examples/rag/utils_rag.py
similarity index 100%
rename from examples/rag/utils.py
rename to examples/rag/utils_rag.py
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 31de9b3922b0..f96868cb5434 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -556,7 +556,9 @@ def forward(
         if encoder_outputs is None:

             if has_to_retrieve:
-                question_enc_outputs = self.question_encoder(input_ids, attention_mask=attention_mask)
+                question_enc_outputs = self.question_encoder(
+                    input_ids, attention_mask=attention_mask, return_dict=True
+                )
                 question_encoder_last_hidden_state = question_enc_outputs[0]  # hidden states of question encoder

                 retriever_outputs = self.retriever(
@@ -616,6 +618,7 @@ def forward(
             decoder_attention_mask=decoder_attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
+            return_dict=True,
         )

         if not has_to_retrieve:
diff --git a/src/transformers/models/rag/retrieval_rag.py b/src/transformers/models/rag/retrieval_rag.py
index fb47fd20596a..8db18a1d65d9 100644
--- a/src/transformers/models/rag/retrieval_rag.py
+++ b/src/transformers/models/rag/retrieval_rag.py
@@ -196,7 +196,7 @@ def __init__(self, vector_size, dataset, index_initialized=False):
         self.dataset = dataset
         self._index_initialized = index_initialized
         self._check_dataset_format(with_index=index_initialized)
-        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
+        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")

     def _check_dataset_format(self, with_index: bool):
         if not isinstance(self.dataset, Dataset):
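To close with an illustration that is not part of the patch: the float32 `Features` hunk above for `use_own_knowledge_dataset.py` sits in a larger flow that turns a tab-separated csv into the `--passages_path` dataset and `--index_path` faiss index used by `finetune_rag.py`. A condensed sketch of that flow, with model names, paths and HNSW parameters taken from the example script and README, so treat them as assumptions:

```python
# Condensed, illustrative sketch of the use_own_knowledge_dataset.py flow.
# Input: a tab-separated csv with "title" and "text" columns; all paths are placeholders.
import faiss
import torch
from datasets import Features, Sequence, Value, load_dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

device = "cuda" if torch.cuda.is_available() else "cpu"
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base").to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")

dataset = load_dataset(
    "csv", data_files=["path/to/my_csv"], split="train", delimiter="\t", column_names=["title", "text"]
)

def embed(batch):
    inputs = ctx_tokenizer(batch["title"], batch["text"], truncation=True, padding="longest", return_tensors="pt")
    with torch.no_grad():
        embeddings = ctx_encoder(inputs["input_ids"].to(device=device), return_dict=True).pooler_output
    return {"embeddings": embeddings.cpu().numpy()}

# Storing the embeddings as float32 (the Features change above) halves the size on disk.
new_features = Features({"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))})
dataset = dataset.map(embed, batched=True, batch_size=16, features=new_features)
dataset.save_to_disk("path/to/my_knowledge_dataset")  # what --passages_path expects

# Build and save the HNSW index that --index_path expects.
index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)
dataset.get_index("embeddings").save("path/to/my_knowledge_dataset_hnsw_index.faiss")
```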