[RAG] Add Ray implementation for distributed retrieval #9197
Merged
Changes from all commits (46 commits)
All commits are authored by amogkam:

ab8bf04  wip
455dc98  Merge branch 'master' of github.com:huggingface/transformers into rag…
2d11672  wip
a02fdbf  wip
2da166c  wip
7853f19  wip
b584095  Merge branch 'master' of github.com:huggingface/transformers into rag…
8ba7e86  Merge branch 'master' of https://github.com/huggingface/transformers …
897d8b7  Merge branch 'rag-ray' of github.com:amogkam/transformers into rag-ray
dc86027  wip
81962a0  wip
0628726  wip
9c62c31  Merge branch 'rag-ray' of github.com:amogkam/transformers into rag-ray
810dd7d  uncomment
6118bab  uncomment
a4a5c79  wip
fc3cee1  updates
48a9dc9  add docstring
581d23b  updates
c034679  fix arg
010f25b  fixes
03ac6b3  add unit tests
b9e109a  update readme
e768ebb  update readme
9166ba9  update finetune script
90c5668  update test
5696b9a  Merge branch 'master' of github.com:huggingface/transformers into rag…
dce22fa  Merge branch 'master' of github.com:huggingface/transformers into rag…
65ee572  add test
7fade9a  add ray to test dependencies
0fb4a82  separate ray and ray tune
532e7d9  formatting
22b239a  shutdown ray at end of test
dd8527b  fix tests
7dd354a  formatting
7d5b4d0  formatting
51b5ef3  even more formatting
2dd4e55  address comments
f1b3d18  formatting
8bddc7f  Merge branch 'master' of github.com:huggingface/transformers into rag…
6872da8  Merge branch 'master' of github.com:huggingface/transformers into rag…
904e53b  add files
c097bfc  Update examples/research_projects/rag/test_distributed_retriever.py
e624cf3  address comments
0e7769c  addressing comments
d327fae  Merge branch 'rag-ray2' of github.com:amogkam/transformers into rag-ray2
One new dependency, `ray`, is added to the test requirements:

```diff
@@ -19,3 +19,4 @@ pytest
 conllu
 sentencepiece != 0.1.92
 protobuf
+ray
```
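The `ray` package provides the actor API that the new retriever builds on. For readers unfamiliar with it, a minimal standalone sketch of a Ray actor (illustrative only, not part of this PR; the `Counter` class is hypothetical):

```python
import ray

ray.init()


@ray.remote
class Counter:
    """A Ray actor: a class instance that lives in its own worker process."""

    def __init__(self):
        self.value = 0

    def increment(self):
        self.value += 1
        return self.value


counter = Counter.remote()  # spawns the actor in a separate process
# Method calls return futures; ray.get blocks until the result is ready.
print(ray.get(counter.increment.remote()))  # -> 1
```

The retrieval workers below follow the same pattern: each `RayRetriever` actor holds its own copy of the index in a separate process, and the training workers talk to them through remote method calls.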
examples/research_projects/rag/distributed_ray_retriever.py (154 additions, 0 deletions)
```python
import logging
import random

import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.file_utils import requires_datasets, requires_faiss
from transformers.models.rag.retrieval_rag import CustomHFIndex


logger = logging.getLogger(__name__)


class RayRetriever:
    def __init__(self):
        self.initialized = False

    def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
        if not self.initialized:
            self.retriever = RagRetriever(
                config,
                question_encoder_tokenizer=question_encoder_tokenizer,
                generator_tokenizer=generator_tokenizer,
                index=index,
                init_retrieval=False,
            )
            self.initialized = True

    def init_retrieval(self):
        self.retriever.index.init_index()

    def retrieve(self, question_hidden_states, n_docs):
        doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
        return doc_ids, retrieved_doc_embeds


class RagRayDistributedRetriever(RagRetriever):
    """
    A distributed retriever built on top of the ``Ray`` API, a library
    for building distributed applications (https://docs.ray.io/en/master/).
    During training, all training workers initialize their own
    instance of a `RagRayDistributedRetriever`, and each instance of
    this distributed retriever shares a common set of Retrieval Ray
    Actors (https://docs.ray.io/en/master/walkthrough.html#remote-classes-actors)
    that load the index on separate processes. Ray handles the
    communication between the `RagRayDistributedRetriever` instances
    and the remote Ray actors. If training is done in a
    non-distributed setup, the index will simply be loaded in the same
    process as the training worker and Ray will not be used.

    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question.
            It is used to decode the question and then use the generator_tokenizer.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
        retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
            These actor classes run on remote processes and are responsible for performing the index lookup.
        index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built using the configuration.
    """

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
        if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
            raise ValueError(
                "When using Ray for distributed fine-tuning, "
                "you'll need to provide the paths instead, "
                "as the dataset and the index are loaded "
                "separately. More info in examples/rag/use_own_knowledge_dataset.py "
            )
        super().__init__(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
        )
        self.retrieval_workers = retrieval_workers
        if len(self.retrieval_workers) > 0:
            ray.get(
                [
                    worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
                    for worker in self.retrieval_workers
                ]
            )

    def init_retrieval(self):
        """
        Retriever initialization function, needs to be called from the
        training process. This function triggers retrieval initialization
        for all retrieval actors if using distributed setting, or loads
        index into current process if training is not distributed.
        """
        logger.info("initializing retrieval")

        if len(self.retrieval_workers) > 0:
            ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
        else:
            # Non-distributed training. Load index into this same process.
            self.index.init_index()

    def retrieve(self, question_hidden_states, n_docs):
        """
        Retrieves documents for specified ``question_hidden_states``. If
        running training with multiple workers, a random retrieval actor is
        selected to perform the index lookup and return the result.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
                The ids of the documents in the index.
            doc_dicts (:obj:`List[dict]`):
                The retrieved document examples per query.
        """
        if len(self.retrieval_workers) > 0:
            # Select a random retrieval actor.
            random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
            doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs))
        else:
            doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
        return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

    @classmethod
    def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)

    @classmethod
    def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
        requires_datasets(cls)
        requires_faiss(cls)
        config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        question_encoder_tokenizer = rag_tokenizer.question_encoder
        generator_tokenizer = rag_tokenizer.generator
        if indexed_dataset is not None:
            config.index_name = "custom"
            index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
        else:
            index = cls._build_index(config)
        return cls(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            retrieval_workers=actor_handles,
            index=index,
        )
```
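A minimal usage sketch of how these pieces fit together, based on the constructor and `from_pretrained` signatures above. The worker count, checkpoint name, and query shapes are illustrative assumptions, not taken from the PR:

```python
import numpy as np
import ray

from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever

ray.init()

# Wrap RayRetriever as a remote actor class and spawn retrieval workers,
# each of which will hold its own copy of the index in a separate process.
RemoteRetriever = ray.remote(RayRetriever)
retrieval_workers = [RemoteRetriever.remote() for _ in range(2)]

retriever = RagRayDistributedRetriever.from_pretrained(
    "facebook/rag-token-nq",  # assumed pretrained RAG checkpoint
    actor_handles=retrieval_workers,
)
retriever.init_retrieval()  # triggers index loading inside each actor

# Retrieve 5 documents for a batch of 2 query vectors.
question_hidden_states = np.random.randn(2, retriever.config.retrieval_vector_size).astype("float32")
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(question_hidden_states, n_docs=5)
```

With an empty `retrieval_workers` list, the same retriever falls back to loading the index in the training process, which is the non-distributed path described in the class docstring.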
Reviewer:

> I'd prefer copy-paste here instead of making a change to the general `retrieval_rag.py` file. Abstracting that much just to not repeat code is not worth it here IMO. Ideally, I'd like to not have a `get_tokenizers` class method at all.

Author (amogkam):

> Got it - I made the change to remove `get_tokenizers` and keep everything in `from_pretrained`. The reason I did this originally was so that any future changes to `retrieval_rag::from_pretrained` wouldn't also have to be made to `distributed_ray_retriever::from_pretrained`, since it might be easy to forget to do this in case the tests don't catch it. This is something we just have to keep in mind.
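For context, a hypothetical sketch of the trade-off being discussed (class names simplified; this is not the PR's actual code): a delegating class method stays in sync with the base class automatically, while the copy-paste approach the reviewer prefers avoids the indirection at the cost of manual synchronization.

```python
class BaseRetriever:
    @classmethod
    def get_tokenizers(cls, name_or_path, **kwargs):
        # Shared tokenizer-loading logic lives here.
        ...


class DelegatingRetriever(BaseRetriever):
    # Approach removed by this PR: a class method whose only job is to
    # forward to the base class. It tracks future BaseRetriever changes
    # automatically, but the indirection was judged not worth it.
    @classmethod
    def get_tokenizers(cls, name_or_path, **kwargs):
        return super().get_tokenizers(name_or_path, **kwargs)


class InlineRetriever(BaseRetriever):
    # Approach kept: duplicate the loading logic inside from_pretrained.
    # Any future change to BaseRetriever.get_tokenizers must be mirrored
    # here by hand, as the author's reply above notes.
    @classmethod
    def from_pretrained(cls, name_or_path, **kwargs):
        ...  # copy of the shared loading logic, inlined
```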