
Conversation

@amogkam
Collaborator

@amogkam amogkam commented Dec 18, 2020

What does this PR do?

This PR adds a new distributed retriever implementation for RAG built on Ray, as an alternative to the current retriever implementation that uses torch.distributed. With Ray, the index can be loaded on multiple processes instead of just the rank 0 training worker, allowing fine-tuning to scale out better to multiple GPUs and potentially allowing the index to fit in GPU memory. This also removes a core dependency on PyTorch, making a TensorFlow implementation of finetune.py possible.

This PR also makes changes to support finetune.py with PyTorch Lightning > v1.0.

A benchmark of PyTorch distributed retrieval vs. Ray distributed retrieval:
[Figure: benchmark comparing PyTorch distributed retrieval and Ray distributed retrieval]

Implementation Details

In the current PyTorch retrieval implementation, the index is loaded only on the rank 0 training worker. Training worker 0 gathers the inputs from all other workers, performs the index lookup, and scatters the results back to the other workers.
[Figure: torch.distributed retrieval, where rank 0 gathers queries, performs the lookup, and scatters results]
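
For illustration only, a minimal sketch (not the PR's actual code) of this gather/lookup/scatter pattern with torch.distributed; retrieve_doc_ids is a hypothetical stand-in for the index lookup:

import torch
import torch.distributed as dist

def retrieve_doc_ids(index, query_batch, n_docs):
    # Hypothetical helper: return the ids of the n_docs nearest documents for each
    # query in the dense index, as a LongTensor of shape (batch, n_docs).
    ...

def pytorch_distributed_retrieve(question_embeds, index, n_docs):
    # question_embeds: (batch, dim) on every rank; `index` is only loaded on rank 0.
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # 1. Rank 0 gathers the query embeddings from all training workers.
    gather_list = [torch.empty_like(question_embeds) for _ in range(world_size)] if rank == 0 else None
    dist.gather(question_embeds, gather_list, dst=0)

    # 2. Rank 0 performs the index lookup for everyone (a synchronization point).
    scatter_list = [retrieve_doc_ids(index, q, n_docs) for q in gather_list] if rank == 0 else None

    # 3. Each worker receives its own results back.
    doc_ids = torch.empty(question_embeds.shape[0], n_docs, dtype=torch.long)
    dist.scatter(doc_ids, scatter_list, src=0)
    return doc_ids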

With the Ray implementation, the index is loaded on separate processes, which are referred to as Ray actors. Each training worker randomly selects a retrieval actor to query for documents and Ray handles all the communication between the processes. Because the index can be loaded in multiple processes, training can scale up since no synchronization needs to happen for the index lookup.
[Figure: Ray-based retrieval, where each training worker queries one of several retrieval actors]
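
Again for illustration, a minimal sketch of the actor pattern described above; load_index and its search method are placeholders, not the PR's API:

import random
import ray

def load_index(index_path):
    # Placeholder for loading the dense document index (e.g. a faiss index).
    ...

@ray.remote
class RetrievalWorker:
    def __init__(self, index_path):
        # Each actor process loads its own copy of the index.
        self.index = load_index(index_path)

    def retrieve(self, question_embeds, n_docs):
        return self.index.search(question_embeds, n_docs)

ray.init()
workers = [RetrievalWorker.remote("path/to/index") for _ in range(4)]

def ray_retrieve(question_embeds, n_docs=5):
    # No gather/scatter and no rank-0 bottleneck: ask any one of the actors.
    return ray.get(random.choice(workers).retrieve.remote(question_embeds, n_docs))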

Note that PyTorch Lightning still handles distributed training, while Ray manages distributed retrieval. Because PTL re-runs the entire training script under the hood in each process, we have to use Ray's named actors feature (https://docs.ray.io/en/master/actors.html?highlight=named%20actors#named-actors) so that the retrieval actors can be referenced from all training processes. The use of named actors is necessitated by how PTL handles distributed training; a simpler approach could probably be used for a TensorFlow implementation.
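
A rough sketch of how named actors make this possible (building on the RetrievalWorker sketch above; is_rank_zero is a hypothetical helper, not the PR's exact code):

import ray

ray.init(address="auto", ignore_reinit_error=True)

if is_rank_zero():  # hypothetical: only one of the PTL-spawned processes creates the actors
    workers = [
        RetrievalWorker.options(name=f"retrieval_worker_{i}").remote("path/to/index")
        for i in range(4)
    ]
else:
    # Every other training process looks the same actors up by name.
    workers = [ray.get_actor(f"retrieval_worker_{i}") for i in range(4)]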

Testing Strategy

Unit tests were added in test_distributed_retriever.py. Note that the local Ray cluster for the tests had to be started with local_mode=True because the test file modifies sys.path, and these changes are not propagated to remote processes. See https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder for more info.
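
For reference, a minimal sketch of the constraint described above (the sys.path entry is hypothetical):

import sys
import ray

sys.path.insert(0, "path/to/examples/rag")  # hypothetical path tweak made by the test file

# With local_mode=True, actors and tasks run serially in this same process,
# so they see the sys.path modification above.
ray.init(local_mode=True)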

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amogkam
Collaborator Author

amogkam commented Dec 18, 2020

Collaborator

@sgugger sgugger left a comment

Thanks for closing/reopening, I have just one last nit!


@classmethod
def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
    return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)
Contributor

I'd prefer copy-paste here instead of making a change to the general retrieval_rag.py file. Abstracting that much just to avoid repeating code is not worth it here IMO. Ideally, I'd like to not have a get_tokenizers class method at all.

Collaborator Author

Got it, I made the change to remove get_tokenizers and keep everything in from_pretrained.

The reason I did this originally was so that any future changes to retrieval_rag::from_pretrained wouldn't also have to be made to distributed_ray_retriever::from_pretrained, since it would be easy to forget to do so if the tests don't catch it. This is just something we'll have to keep in mind.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Thanks so much for working on this! The PR looks great except for the get_tokenizers(...) class method. Could we try not to split up the from_pretrained(...) in retrieval_rag.py, at the cost of maybe copy-pasting some code?

Also before merging I'd like @lhoestq to take a quick look - he probably knows best here.

Member

@lhoestq lhoestq left a comment

Looks all good, thanks!

Looking forward to using it :)

@patrickvonplaten
Contributor

Nice, good to merge then!

@patrickvonplaten patrickvonplaten merged commit a4b21cd into huggingface:master Dec 21, 2020
@richardliaw
Collaborator

Awesome, thank you so much for the reviews @lhoestq @patrickvonplaten -- happy holidays!

@amogkam
Collaborator Author

amogkam commented Dec 21, 2020

Thanks guys!

@shamanez

@amogkam @patrickvonplaten I need some help to implement an end-to-end retrieval training feature for RAG with Ray.

How can I run document encoding and indexing with an updated doc encoder (the context encoder network that is kept frozen in the original RAG) using a Ray actor separate from the main training process?

How can I access the document index inside the Ray actors during training in case I want to update the index, say every 5000 steps?
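
Roughly the kind of thing I have in mind, as a hypothetical sketch rather than an existing API (load_index and build_index are placeholders):

import ray

@ray.remote
class ReindexableRetrievalWorker:
    def __init__(self, index_path):
        self.index = load_index(index_path)  # placeholder

    def retrieve(self, question_embeds, n_docs):
        return self.index.search(question_embeds, n_docs)

    def reindex(self, ctx_encoder_path, passages_path):
        # Re-encode the passages with the updated context encoder and swap in the
        # freshly built index, off the main training process.
        self.index = build_index(ctx_encoder_path, passages_path)  # placeholder

# In the training loop (sketch):
# if global_step % 5000 == 0:
#     ray.get([worker.reindex.remote(ckpt_dir, passages_path) for worker in workers])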

@richardliaw
Collaborator

@shamanez could you open a new issue to track this?

@shamanez

shamanez commented Feb 11, 2021

@richardliaw

I already opened one a few weeks ago. Please refer to this issue.

I added a new issue explaining the exact problem in this one.
