Prototype torch.distributed integration #158
Conversation
rl_games/common/a2c_common.py
Outdated
@@ -966,7 +973,7 @@ def train(self):
    update_time = 0
    if self.multi_gpu:
        should_exit_t = torch.tensor(should_exit).float()
-       self.hvd.broadcast_value(should_exit_t, 'should_exit')
+       # self.hvd.broadcast_value(should_exit_t, 'should_exit')  # what is the purpose of this?
There is a chance that one job will exit a little bit earlier; as a result, the other jobs will crash.
This method seems to be deprecated by Horovod: the closest method I can find is https://horovod.readthedocs.io/en/stable/api.html#horovod.torch.broadcast_. What exactly is this `broadcast_value` doing? If rank 0's `should_exit_t=False` and rank 1's `should_exit_t=True`, would `broadcast_value` overwrite rank 1's `should_exit_t`?
yes
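(For context: `torch.distributed` has the same overwrite semantics via its in-place `dist.broadcast`. Below is a minimal sketch of how the `should_exit` synchronization could be ported; the helper name `sync_should_exit` is illustrative, not the PR's code, and it assumes the process group is already initialized.)

```python
import torch
import torch.distributed as dist

# Sketch only: assumes dist.init_process_group(...) was called earlier
# and that `device` matches the backend (CUDA tensors for NCCL).
def sync_should_exit(should_exit: bool, device: torch.device) -> bool:
    """Broadcast rank 0's exit flag so every rank leaves the loop together."""
    should_exit_t = torch.tensor(float(should_exit), device=device)
    # dist.broadcast is in-place: after this call, every rank's tensor
    # holds rank 0's value, matching hvd.broadcast_value's overwrite behavior.
    dist.broadcast(should_exit_t, src=0)
    return bool(should_exit_t.item())
```

Without this step, rank 0 could leave the training loop while the other ranks block in their next collective call, which is the crash mode described above.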
@@ -0,0 +1,84 @@
+params:
btw could you use breakout from envpool?
Envpool is actually trickier. When I prototyped with envpool, multi-GPU was actually slower, at least out of the box. This is because envpool uses separate threads and has complex interactions with those threads that are difficult to control. For this reason, I chose the regular gym API for predictable performance.
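(For reference, envpool drives many environments behind a single gym-style handle on its own worker threads, which is why it interacts differently with multi-GPU setups than a list of per-process gym envs. A minimal sketch of the batched API; the task id and sizes here are illustrative:)

```python
import numpy as np
import envpool

# One handle steps num_envs environments on envpool's internal threads.
env = envpool.make("Breakout-v5", env_type="gym", num_envs=8)
obs = env.reset()                       # batched observations, one row per env
actions = np.zeros(8, dtype=np.int64)   # one action per environment
obs, rewards, dones, info = env.step(actions)
```

Because envpool owns those worker threads, controlling how they are pinned when several GPU processes share a node is the part that is hard to tune out of the box.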
cc @markelsanz14, I prototyped the torch.distributed integration.
Force-pushed from 7a49df2 to 86f5e82.
This PR prototypes torch.distributed integration for multi-GPU training.
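A rough sketch of the shape such an integration takes, assuming one process per GPU launched with torchrun; the helper names below are illustrative, not necessarily what the PR uses:

```python
import os
import torch
import torch.distributed as dist

def init_distributed() -> int:
    """torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for each process."""
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    return local_rank

def average_gradients(model: torch.nn.Module) -> None:
    """All-reduce gradients after backward, as Horovod's DistributedOptimizer did."""
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)
```

Launched with something like `torchrun --nproc_per_node=2 train.py ...`, this replaces Horovod's `horovodrun` entry point.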