diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 54b3e1ef..e083b0b5 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -1047,8 +1047,12 @@ def train(self): torch.cuda.set_device(self.local_rank) print("====================broadcasting parameters") model_params = [self.model.state_dict()] + if self.has_central_value: + model_params.append(self.central_value_net.state_dict()) dist.broadcast_object_list(model_params, 0) self.model.load_state_dict(model_params[0]) + if self.has_central_value: + self.central_value_net.load_state_dict(model_params[1]) while True: epoch_num = self.update_epoch() @@ -1326,8 +1330,12 @@ def train(self): torch.cuda.set_device(self.local_rank) print("====================broadcasting parameters") model_params = [self.model.state_dict()] + if self.has_central_value: + model_params.append(self.central_value_net.state_dict()) dist.broadcast_object_list(model_params, 0) self.model.load_state_dict(model_params[0]) + if self.has_central_value: + self.central_value_net.load_state_dict(model_params[1]) while True: epoch_num = self.update_epoch()