[RAG] Add Ray implementation for distributed retrieval #9197
Conversation
Thanks for closing/reopening, I have just one last nit!
Co-authored-by: Sylvain Gugger <[email protected]>
    @classmethod
    def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)
I'd prefer copy-paste here instead of making a change to the general retrieval_rag.py file. Abstracting that much just to avoid repeating code is not worth it here IMO. Ideally, I'd like to not have a get_tokenizers class method at all.
Got it, I made the change to remove get_tokenizers and keep everything in from_pretrained.
The reason I did this originally was so that any future changes to retrieval_rag::from_pretrained wouldn't also have to be made to distributed_ray_retriever::from_pretrained, since it would be easy to forget to do this if the tests don't catch it. This is something we just have to keep in mind.
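A minimal sketch of what that direction looks like, with the config/tokenizer loading duplicated inline in the subclass's from_pretrained instead of shared via a get_tokenizers helper (names and arguments are simplified for illustration, not copied from the actual diff):

```python
@classmethod
def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
    # Config and tokenizer loading is copy-pasted here rather than being
    # factored out into a shared get_tokenizers class method.
    config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
    rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
    question_encoder_tokenizer = rag_tokenizer.question_encoder
    generator_tokenizer = rag_tokenizer.generator
    ...  # build and return the RagRayDistributedRetriever
```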
patrickvonplaten left a comment
Thanks so much for working on this! The PR looks great except for the get_tokenizers(...) class method. Could we try not to split up from_pretrained(...) in retrieval_rag.py, even at the cost of maybe copy-pasting some code?
Also before merging I'd like @lhoestq to take a quick look - he probably knows best here.
lhoestq left a comment
Looks all good, thanks! Looking forward to using it :)
Nice, good to merge then!

Awesome, thank you so much for the reviews @lhoestq @patrickvonplaten -- happy holidays!

Thanks guys!

@amogkam @patrickvonplaten I need some help to implement an end-to-end retrieval training feature for RAG with Ray. How can I run document encoding and indexing with an updated doc encoder (the context encoder network that is kept frozen in the original RAG) using a Ray actor separate from the main training process? And how can I access the document index inside Ray actors during training, in case I want to update the index, say, every 5000 steps?

@shamanez could you open a new issue to track this?
What does this PR do?
This PR adds a new distributed retriever implementation for RAG built on Ray, as an alternative to the current retriever implementation that uses torch.distributed. With Ray, the index can be loaded in multiple processes instead of just on the rank 0 training worker, allowing fine-tuning to scale out better to multiple GPUs and potentially allowing the index to fit in GPU memory. This also removes a core dependency on Pytorch, opening the door to a Tensorflow implementation of finetune.py.

This PR also makes changes to support finetune.py with Pytorch Lightning >v1.0.
A benchmark of Pytorch distributed retrieval vs. Ray distributed retrieval:

Implementation Details
In the current Pytorch retrieval implementation, the index is loaded once, on just the rank 0 training worker. Training worker 0 gathers the inputs from all other workers, performs the index lookup, and scatters the results back to the other workers.
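For intuition, here is a minimal, hypothetical sketch of that gather / lookup / scatter pattern. It is not the code in the PR; tensor shapes, process-group handling, and the index's get_top_docs call are simplifying assumptions.

```python
import torch
import torch.distributed as dist


def rank0_retrieve(query_vectors, n_docs, index, world_size):
    """Hypothetical sketch: only rank 0 holds the index and performs the lookup."""
    rank = dist.get_rank()

    # 1. Rank 0 gathers the query embeddings from every training worker.
    gathered = [torch.empty_like(query_vectors) for _ in range(world_size)] if rank == 0 else None
    dist.gather(query_vectors, gather_list=gathered, dst=0)

    if rank == 0:
        # 2. Rank 0 looks up documents for all workers (get_top_docs is an assumed index API).
        scatter_list = [torch.as_tensor(index.get_top_docs(q.numpy(), n_docs)).float() for q in gathered]
    else:
        scatter_list = None

    # 3. Each worker receives its own slice of the retrieved document vectors.
    output = torch.empty(query_vectors.shape[0], n_docs, query_vectors.shape[-1])
    dist.scatter(output, scatter_list=scatter_list, src=0)
    return output
```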

With the Ray implementation, the index is loaded in separate processes, referred to as Ray actors. Each training worker randomly selects a retrieval actor to query for documents, and Ray handles all the communication between processes. Because the index can be loaded in multiple processes, training scales up better, since the index lookup requires no synchronization between training workers.
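A rough illustration of this pattern follows. The RetrievalActor and FakeIndex classes below are stand-ins invented for the sketch, not the classes added by this PR.

```python
import random

import numpy as np
import ray


class FakeIndex:
    """Stand-in for the real document index, for illustration only."""

    def get_top_docs(self, query_vectors, n_docs):
        return np.zeros((len(query_vectors), n_docs, query_vectors.shape[-1]))


@ray.remote
class RetrievalActor:
    """Hypothetical Ray actor that holds its own copy of the document index."""

    def __init__(self, index):
        self.index = index

    def retrieve(self, query_vectors, n_docs):
        # The index lookup runs inside the actor's process, not in the training worker.
        return self.index.get_top_docs(query_vectors, n_docs)


ray.init()

# Load the index in several separate processes (one per actor).
retrieval_workers = [RetrievalActor.remote(FakeIndex()) for _ in range(4)]

# A training worker picks a random actor and queries it; Ray handles the
# inter-process communication, and workers never synchronize for retrieval.
chosen = random.choice(retrieval_workers)
docs = ray.get(chosen.retrieve.remote(np.random.randn(8, 768), n_docs=5))
```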

Note that Pytorch Lightning still handles distributed training, while Ray manages distributed retrieval. Because PTL calls the entire training script under the hood multiple times, we have to use Ray's named actors feature (https://docs.ray.io/en/master/actors.html?highlight=named%20actors#named-actors), which allows the retrieval actors to be referenced by all training processes. The use of named actors is necessitated by how PTL handles distributed training, and a simpler approach could probably be used for a Tensorflow implementation.
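A minimal sketch of the named-actor pattern, reusing the hypothetical RetrievalActor and FakeIndex from the previous sketch (the actor name is illustrative, and lifetime/namespace details are omitted):

```python
import numpy as np
import ray

# In the process that sets up retrieval, create the actor under a fixed name
# so other processes can find it without being handed an actor handle.
RetrievalActor.options(name="retrieval_worker_0").remote(FakeIndex())

# In each training process that PTL launches, look the actor up by its name.
worker = ray.get_actor("retrieval_worker_0")
docs = ray.get(worker.retrieve.remote(np.random.randn(8, 768), n_docs=5))
```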
Testing Strategy
Unit tests were added to test_distributed_retriever.py. Note that the local Ray cluster for the tests had to be started with local_mode=True, because the test file modifies sys.path and these changes are not propagated to remote processes. See https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder for more info.
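For reference, starting such a local-mode cluster in a test looks roughly like this (a sketch, not the test code itself):

```python
import ray

# local_mode=True runs "remote" tasks and actors serially in the driver process,
# so modifications the test makes to sys.path remain visible to that code.
ray.init(local_mode=True)
try:
    ...  # exercise the Ray-based retriever as in test_distributed_retriever.py
finally:
    ray.shutdown()
```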
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.