Skip to content

Commit

Permalink
Add a broadcast for the initial parameters of central_value_net in mu…
Browse files Browse the repository at this point in the history
…lti-GPU/node training. (#297)
  • Loading branch information
annan-tang authored Sep 11, 2024
1 parent 7f9cd1e commit 59d4c40
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,8 +1047,12 @@ def train(self):
torch.cuda.set_device(self.local_rank)
print("====================broadcasting parameters")
model_params = [self.model.state_dict()]
if self.has_central_value:
model_params.append(self.central_value_net.state_dict())
dist.broadcast_object_list(model_params, 0)
self.model.load_state_dict(model_params[0])
if self.has_central_value:
self.central_value_net.load_state_dict(model_params[1])

while True:
epoch_num = self.update_epoch()
Expand Down Expand Up @@ -1326,8 +1330,12 @@ def train(self):
torch.cuda.set_device(self.local_rank)
print("====================broadcasting parameters")
model_params = [self.model.state_dict()]
if self.has_central_value:
model_params.append(self.central_value_net.state_dict())
dist.broadcast_object_list(model_params, 0)
self.model.load_state_dict(model_params[0])
if self.has_central_value:
self.central_value_net.load_state_dict(model_params[1])

while True:
epoch_num = self.update_epoch()
Expand Down

0 comments on commit 59d4c40

Please sign in to comment.