[RAG] Add Ray implementation for distributed retrieval #8583
Conversation
Hi! This looks awesome :)

@lhoestq yes that sounds great!

Yes indeed! Feel free to set this PR to ready for review. Also it looks like the CI fails because of a failed import of … You should also add …

@lhoestq CI is passing now!

@lhoestq any ETA on when this PR can get reviewed? Thanks

Hi! I've already started to look at the changes and it looks pretty good so far :) I'll finish my review soon, probably tomorrow

Awesome, thanks!
lhoestq left a comment
Really good! Thank you for adding Ray support for RAG fine-tuning :)
And the speed-up compared to using only one worker for retrieval is pretty cool.
I left a few comments, mainly about separating the pytorch tests from the ray tests.
examples/rag/README.md (Outdated)
    python examples/rag/finetune.py \
        --data_dir $DATA_DIR \
        --output_dir $OUTPUT_DIR \
        --model_name_or_path $MODEL_NAME_OR_PATH \
        --model_type rag_sequence \
        --fp16 \
        --gpus 8 \
        --distributed_retriever ray \
        --num_retrieval_workers 4
Maybe add an example for torch as well?
Currently distributed_retriever defaults to pytorch, so an example command for this would just be the same as the command earlier in the README. I added a sentence saying that the default is pytorch, though.
    import ray  # noqa: F401

    _has_ray = True
    try:
adding ray integration here cc @LysandreJik
@sgugger it would be cool if you could review as this changes some things in the trainer/integrations.
sgugger left a comment
There are other instances of is_ray_available to change to is_ray_tune_available if we go with the name change:
- in integrations.py, inside the functions hp_params and default_hp_search_backend
- in trainer_utils.py, inside the function default_hp_space_ray

The main __init__ should also be updated to provide the two functions.
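For illustration, here is a minimal sketch of what the renamed availability checks could look like in integrations.py. This is not code from the PR; the exact try/except structure and the function bodies are assumptions based on the fragment shown above and the function names in this comment.

    try:
        import ray  # noqa: F401

        _has_ray = True
    except ImportError:
        _has_ray = False

    try:
        from ray import tune  # noqa: F401

        _has_ray_tune = True
    except ImportError:
        _has_ray_tune = False


    def is_ray_available():
        # True if the ray package itself can be imported.
        return _has_ray


    def is_ray_tune_available():
        # True only if Ray Tune is importable as well.
        return _has_ray_tune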
Hi there, sorry for the delay. Could you close and reopen your PR? Because of a bad force-push on our side, the diff has become unreadable. Also, the examples folder has slightly changed structure, so you might need to move the folder. Ping me, @patrickvonplaten and @LysandreJik on the PR you reopen and we'll look at it quickly.
Opened a new one here: #9197!
What does this PR do?
This PR adds a new distributed retriever implementation for RAG built on Ray, as an alternative to the current retriever implementation that uses torch.distributed. With Ray it's possible to load the index on multiple processes instead of just the rank 0 training worker, allowing fine-tuning to scale out better to multiple GPUs and also allowing the index to potentially fit in GPU memory. This also removes a core dependency on Pytorch, enabling a Tensorflow implementation of finetune.py. This PR also makes changes to support finetune.py with Pytorch Lightning >v1.0.
A benchmark of Pytorch distributed retrieval vs. Ray distributed retrieval:

Implementation Details
In the current Pytorch retrieval implementation, the index is loaded once on just the rank 0 training worker. Training worker 0 gathers the inputs from all other workers, performs the index lookup, and scatters the results back to the other workers.

With the Ray implementation, the index is loaded on separate processes, which are referred to as Ray actors. Each training worker randomly selects a retrieval actor to query for documents and Ray handles all the communication between the processes. Because the index can be loaded in multiple processes, training can scale up since no synchronization needs to happen for the index lookup.

Note that Pytorch Lightning still handles distributed training, while Ray manages distributed retrieval. Because PTL calls the entire training script multiple times under the hood, we have to use Ray's named actors feature (https://docs.ray.io/en/master/actors.html?highlight=named%20actors#named-actors), which allows the retrieval actors to be referenced by all training processes. The use of named actors is necessitated by how PTL handles distributed training, and a simpler approach could probably be used for a Tensorflow implementation.
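A short sketch of the named-actor lookup (the actor class and the name are hypothetical; only the options(name=...) and ray.get_actor(...) calls are the point here):

    import ray

    @ray.remote
    class RetrievalWorker:
        def retrieve(self, query):
            return f"documents for {query!r}"

    ray.init(ignore_reinit_error=True)

    # The process that sets up retrieval registers the actor under a well-known name ...
    handle = RetrievalWorker.options(name="retrieval_worker_0").remote()

    # ... and any training process spawned later can fetch the same actor by
    # name instead of loading another copy of the index.
    worker = ray.get_actor("retrieval_worker_0")
    print(ray.get(worker.retrieve.remote("ray")))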
Testing Strategy
Unit tests were added to test_distributed_retriever.py. Note that the local Ray cluster for the tests had to be started with local_mode=True because the test file modifies sys.path and these changes are not propagated to remote processes. See https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder for more info.

Fixes # (issue)
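For reference, a minimal sketch of the local-mode test setup described above (the sys.path entry is a placeholder, not the path used by the actual tests):

    import sys
    import ray

    # The test file tweaks sys.path so that the example modules can be imported.
    sys.path.insert(0, "examples/rag")

    # local_mode=True runs tasks and actors in the driver process, so the
    # sys.path change above is also visible to the retrieval workers.
    ray.init(local_mode=True)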
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.