
Prototype torch.distributed integration #158

Closed

Conversation

@vwxyzjn (Contributor) commented May 20, 2022

This PR prototypes torch.distributed integration for multi-GPU training.
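For context, here is a minimal, generic sketch of what a torch.distributed setup typically looks like (illustrative only, assuming a torchrun launch; this is not the code from this PR):

```python
# Generic torch.distributed setup sketch; typically launched with:
#   torchrun --nproc_per_node=NUM_GPUS train.py
import os

import torch
import torch.distributed as dist

def init_distributed() -> int:
    # torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

# Gradients can then be averaged across GPUs by wrapping the model:
#   model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```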

@@ -966,7 +973,7 @@ def train(self):
             update_time = 0
             if self.multi_gpu:
                 should_exit_t = torch.tensor(should_exit).float()
-                self.hvd.broadcast_value(should_exit_t, 'should_exit')
+                # self.hvd.broadcast_value(should_exit_t, 'should_exit')  # what is the purpose of this?
Owner
There is a chance that one job will exit a little bit earlier; as a result, the other jobs will crash.

Contributor Author
This method seems to have been deprecated in Horovod: the closest method I can find is https://horovod.readthedocs.io/en/stable/api.html#horovod.torch.broadcast_.

What exactly is this broadcast_value doing? If rank 0's should_exit_t=False and rank 1's should_exit_t=True, would broadcast_value overwrite rank 1's should_exit_t?

Owner
Yes.
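For reference, here is a minimal sketch of how the same exit-flag synchronization could look with torch.distributed (the helper name is hypothetical and assumes the process group is already initialized; this is not the code from this PR):

```python
import torch
import torch.distributed as dist

def sync_should_exit(should_exit: bool, device: torch.device) -> bool:
    """Broadcast rank 0's exit decision so that all ranks stop together.

    Hypothetical helper; assumes dist.init_process_group(...) was already called.
    """
    should_exit_t = torch.tensor([float(should_exit)], device=device)
    # Every rank receives rank 0's value, overwriting its own local flag
    # (the same overwrite behavior as hvd.broadcast_value discussed above).
    dist.broadcast(should_exit_t, src=0)
    return bool(should_exit_t.item())
```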

@@ -0,0 +1,84 @@
+params:
Owner
BTW, could you use Breakout from envpool?

Contributor Author
Envpool is actually trickier. When prototyping with envpool, multi-GPU is actually slower, at least out of the box. This is because envpool uses separate threads and has complex interactions with those threads that are a bit difficult to control. For this reason, I have chosen the regular Gym API so the performance comparison stays controlled.

@vwxyzjn (Contributor Author) commented May 24, 2022

cc @markelsanz14, I prototyped the torch.distributed integration, but it's only 6% faster. I still feel I am missing a bottleneck somewhere, because the prototype with CleanRL was about 25% faster.
