Commits (42)
ab8bf04  wip (amogkam, Oct 22, 2020)
455dc98  Merge branch 'master' of github.com:huggingface/transformers into rag… (amogkam, Oct 22, 2020)
2d11672  wip (amogkam, Oct 23, 2020)
a02fdbf  wip (amogkam, Oct 29, 2020)
2da166c  wip (Nov 5, 2020)
7853f19  wip (amogkam, Nov 5, 2020)
b584095  Merge branch 'master' of github.com:huggingface/transformers into rag… (amogkam, Nov 5, 2020)
8ba7e86  Merge branch 'master' of https://github.com/huggingface/transformers … (amogkam, Nov 5, 2020)
897d8b7  Merge branch 'rag-ray' of github.com:amogkam/transformers into rag-ray (amogkam, Nov 5, 2020)
dc86027  wip (amogkam, Nov 5, 2020)
81962a0  wip (amogkam, Nov 7, 2020)
0628726  wip (amogkam, Nov 10, 2020)
9c62c31  Merge branch 'rag-ray' of github.com:amogkam/transformers into rag-ray (amogkam, Nov 10, 2020)
810dd7d  uncomment (amogkam, Nov 10, 2020)
6118bab  uncomment (amogkam, Nov 10, 2020)
a4a5c79  wip (amogkam, Nov 13, 2020)
fc3cee1  updates (amogkam, Nov 16, 2020)
48a9dc9  add docstring (amogkam, Nov 16, 2020)
581d23b  updates (amogkam, Nov 16, 2020)
c034679  fix arg (amogkam, Nov 16, 2020)
010f25b  fixes (amogkam, Nov 16, 2020)
03ac6b3  add unit tests (amogkam, Nov 17, 2020)
b9e109a  update readme (amogkam, Nov 17, 2020)
e768ebb  update readme (amogkam, Nov 17, 2020)
9166ba9  update finetune script (amogkam, Nov 17, 2020)
90c5668  update test (amogkam, Nov 17, 2020)
5696b9a  Merge branch 'master' of github.com:huggingface/transformers into rag… (amogkam, Nov 17, 2020)
dce22fa  Merge branch 'master' of github.com:huggingface/transformers into rag… (amogkam, Nov 23, 2020)
65ee572  add test (amogkam, Nov 23, 2020)
7fade9a  add ray to test dependencies (amogkam, Nov 26, 2020)
0fb4a82  separate ray and ray tune (amogkam, Nov 26, 2020)
532e7d9  formatting (amogkam, Nov 26, 2020)
22b239a  shutdown ray at end of test (amogkam, Nov 26, 2020)
dd8527b  fix tests (amogkam, Nov 26, 2020)
7dd354a  formatting (amogkam, Nov 26, 2020)
7d5b4d0  formatting (amogkam, Nov 26, 2020)
51b5ef3  even more formatting (amogkam, Nov 26, 2020)
2dd4e55  address comments (amogkam, Nov 30, 2020)
f1b3d18  formatting (amogkam, Nov 30, 2020)
8bddc7f  Merge branch 'master' of github.com:huggingface/transformers into rag… (amogkam, Nov 30, 2020)
6872da8  Merge branch 'master' of github.com:huggingface/transformers into rag… (amogkam, Dec 17, 2020)
904e53b  add files (amogkam, Dec 17, 2020)
1 change: 1 addition & 0 deletions examples/legacy/pytorch-lightning/requirements.txt
@@ -19,3 +19,4 @@ pytest
conllu
sentencepiece != 0.1.92
protobuf
ray
38 changes: 38 additions & 0 deletions examples/research_projects/rag/README.md
@@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \
```
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.

## Document Retrieval
When running distributed fine-tuning, each training worker needs to retrieve contextual documents
for its input by querying an index loaded into memory. RAG provides two implementations for document retrieval,
one based on the [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other
on [`Ray`](https://docs.ray.io/en/master/).

This option can be configured with the `--distributed_retriever` flag, which can be set to either `pytorch` or `ray`.
By default, this flag is set to `pytorch`.

For the PyTorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used
to collect the inputs from the other training workers and send back the corresponding document embeddings.
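
Roughly, the pattern looks like this (a minimal sketch, not the actual code in `distributed_pytorch_retriever.py`; `index.search` is a hypothetical stand-in for the real index lookup):

```python
import torch
import torch.distributed as dist


def distributed_retrieve(question_hidden_states: torch.Tensor, n_docs: int, index=None):
    """Sketch: rank 0 owns the index, gathers all queries, and scatters results back."""
    rank, world_size = dist.get_rank(), dist.get_world_size()
    batch_size, dim = question_hidden_states.shape

    # 1. Gather every worker's query batch onto rank 0.
    gathered = [torch.empty_like(question_hidden_states) for _ in range(world_size)] if rank == 0 else None
    dist.gather(question_hidden_states, gather_list=gathered, dst=0)

    # 2. Rank 0 performs the index lookup for each worker's batch
    #    (`index.search` is a hypothetical stand-in for the real lookup).
    results = [index.search(queries, n_docs) for queries in gathered] if rank == 0 else None

    # 3. Scatter each worker's (batch_size, n_docs, dim) document embeddings back to it.
    doc_embeds = torch.empty(batch_size, n_docs, dim)
    dist.scatter(doc_embeds, scatter_list=results, src=0)
    return doc_embeds
```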

For the Ray implementation, the index is loaded in *separate* process(es). The training workers randomly select which
retriever worker to query. To use Ray for distributed retrieval, you have to set the `--distributed_retriever` arg to `ray`.
To configure the number of retrieval workers (the number of processes that load the index), you can set the `--num_retrieval_workers` flag.
Also make sure to start the Ray cluster before running fine-tuning.

```bash
# Start a single-node Ray cluster.
ray start --head

python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
--model_type rag_sequence \
--fp16 \
    --gpus 8 \
--distributed_retriever ray \
--num_retrieval_workers 4

# Stop the Ray cluster once fine-tuning has finished.
ray stop
```

Using Ray can lead to retrieval speedups in multi-GPU settings since multiple processes load the index rather than
just the rank 0 training worker. Using Ray also allows you to load the index on GPU, since the index is loaded in a separate
process from the model, while with PyTorch distributed retrieval both are loaded in the same process, potentially leading to GPU OOM.
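
For reference, here is a rough sketch of how the Ray retrieval actors are wired to the retriever (`finetune_rag.py` does this setup for you; the worker count and model name below are only examples):

```python
import ray

# Run from the RAG example directory so the local module is importable.
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
from transformers import RagSequenceForGeneration

# Connect to the Ray cluster started above with `ray start --head`.
ray.init(address="auto")

# Spawn the retrieval workers; each actor loads its own copy of the index
# in a separate process once init_retrieval() is called.
retrieval_workers = [ray.remote(RayRetriever).remote() for _ in range(4)]

# Build the distributed retriever on top of the actor handles and hand it to the model.
retriever = RagRayDistributedRetriever.from_pretrained(
    "facebook/rag-sequence-nq", actor_handles=retrieval_workers
)
retriever.init_retrieval()
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
```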

# Evaluation
Our evaluation script supports two modes of evaluation (controlled by the `eval_mode` argument): `e2e`, which performs end-to-end evaluation and returns EM (exact match) and F1 scores calculated for the downstream task, and `retrieval`, which returns precision@k of the documents retrieved for the provided inputs.
16 changes: 15 additions & 1 deletion examples/research_projects/rag/_test_finetune_rag.py
@@ -9,6 +9,7 @@
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
require_ray,
require_torch_gpu,
require_torch_multi_gpu,
)
@@ -29,7 +30,7 @@ def _create_dummy_data(self, data_dir):
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
f.write(content)

def _run_finetune(self, gpus: int):
def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

@@ -66,6 +67,7 @@ def _run_finetune(self, gpus: int):
--gradient_accumulation_steps 1 \
--distributed-port 8787 \
--use_dummy_dataset 1 \
--distributed_retriever {distributed_retriever} \
""".split()

if gpus > 0:
@@ -94,3 +96,15 @@ def test_finetune_gpu(self):
def test_finetune_multigpu(self):
result = self._run_finetune(gpus=2)
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

@require_torch_gpu
@require_ray
def test_finetune_gpu_ray_retrieval(self):
result = self._run_finetune(gpus=1, distributed_retriever="ray")
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

@require_torch_multi_gpu
@require_ray
def test_finetune_multigpu_ray_retrieval(self):
        result = self._run_finetune(gpus=2, distributed_retriever="ray")
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
examples/research_projects/rag/distributed_pytorch_retriever.py
@@ -31,14 +31,13 @@ class RagPyTorchDistributedRetriever(RagRetriever):
If specified, use this index instead of the one built using the configuration
"""

_init_retrieval = False

def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
super().__init__(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.process_group = None

144 changes: 144 additions & 0 deletions examples/research_projects/rag/distributed_ray_retriever.py
@@ -0,0 +1,144 @@
import logging
import random

import ray
from transformers import RagRetriever


logger = logging.getLogger(__name__)


class RayRetriever:
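    """Ray actor class that wraps a ``RagRetriever`` instance.

    Each remote ``RayRetriever`` actor builds its own retriever lazily via
    ``create_rag_retriever``, loads the index into its own process when
    ``init_retrieval`` is called, and then serves index lookups through
    ``retrieve`` for the training workers.
    """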
def __init__(self):
self.initialized = False

def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
if not self.initialized:
self.retriever = RagRetriever(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.initialized = True

def init_retrieval(self):
self.retriever.index.init_index()

def retrieve(self, question_hidden_states, n_docs):
doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
return doc_ids, retrieved_doc_embeds


class RagRayDistributedRetriever(RagRetriever):
"""
    A distributed retriever built on top of the ``Ray`` API, a library
    for building distributed applications (https://docs.ray.io/en/master/).
    During training, all training workers initialize their own
    instance of a `RagRayDistributedRetriever`, and each instance of
    this distributed retriever shares a common set of Retrieval Ray
    Actors (https://docs.ray.io/en/master/walkthrough.html#remote-classes-actors)
    that load the index on separate processes. Ray
handles the communication between the `RagRayDistributedRetriever`
instances and the remote Ray actors. If training is done in a
non-distributed setup, the index will simply be loaded in the same
process as the training worker and Ray will not be used.

Args:
config (:class:`~transformers.RagConfig`):
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question.
            It is used to decode the question and then use the generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
These actor classes run on remote processes and are responsible for performing the index lookup.
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
If specified, use this index instead of the one built using the configuration
"""

def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
raise ValueError(
"When using Ray for distributed fine-tuning, "
"you'll need to provide the paths instead, "
"as the dataset and the index are loaded "
"separately. More info in examples/rag/use_own_knowledge_dataset.py "
)
super().__init__(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.retrieval_workers = retrieval_workers
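        # Build the underlying RagRetriever inside each remote retrieval actor
        # (the index itself is loaded later, when init_retrieval is called).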
if len(self.retrieval_workers) > 0:
ray.get(
[
worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
for worker in self.retrieval_workers
]
)

def init_retrieval(self):
"""
        Retriever initialization function, needs to be called from the
        training process. This function triggers retrieval initialization
        for all retrieval actors when using the distributed setting, or loads the
        index into the current process if training is not distributed.
"""
logger.info("initializing retrieval")

if len(self.retrieval_workers) > 0:
ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
else:
# Non-distributed training. Load index into this same process.
self.index.init_index()

def retrieve(self, question_hidden_states, n_docs):
"""
Retrieves documents for specified ``question_hidden_states``. If
running training with multiple workers, a random retrieval actor is
selected to perform the index lookup and return the result.

Args:
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
A batch of query vectors to retrieve with.
n_docs (:obj:`int`):
The number of docs retrieved per query.

Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
                The ids of the retrieved documents in the index.
            doc_dicts (:obj:`List[dict]`):
                The contents (e.g. title and text) of the retrieved documents per query.
"""
if len(self.retrieval_workers) > 0:
# Select a random retrieval actor.
random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs))
else:
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

@classmethod
def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)

@classmethod
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
config, question_encoder_tokenizer, generator_tokenizer, index = cls.get_tokenizers(
retriever_name_or_path, indexed_dataset, **kwargs
)
return cls(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
retrieval_workers=actor_handles,
index=index,
)