Skip to content

Commit

Permalink
Merge from master.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Jun 27, 2024
2 parents 6a91bd3 + 07043a3 commit 46c8c48
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
5 changes: 3 additions & 2 deletions rl_games/algos_torch/players.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def __init__(self, params):
]

obs_shape = self.obs_shape
- self.normalize_input = False
+ self.normalize_input = self.config.get('normalize_input', False)
config = {
'obs_dim': self.env_info["observation_space"].shape[0],
'action_dim': self.env_info["action_space"].shape[0],
Expand Down Expand Up @@ -229,8 +229,9 @@ def restore(self, fn):
def get_action(self, obs, is_deterministic=False):
if self.has_batch_dimension == False:
obs = unsqueeze_obs(obs)
+ obs = self.model.norm_obs(obs)
dist = self.model.actor(obs)
- actions = dist.sample() if is_deterministic else dist.mean
+ actions = dist.sample() if not is_deterministic else dist.mean
actions = actions.clamp(*self.action_range).to(self.device)
if self.has_batch_dimension == False:
actions = torch.squeeze(actions.detach())
Expand Down
2 changes: 2 additions & 0 deletions rl_games/algos_torch/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def get_weights(self):
state = {'actor': self.model.sac_network.actor.state_dict(),
'critic': self.model.sac_network.critic.state_dict(),
'critic_target': self.model.sac_network.critic_target.state_dict()}
+ if self.normalize_input:
+ state['running_mean_std'] = self.model.running_mean_std.state_dict()
return state

def save(self, fn):
Expand Down
1 change: 1 addition & 0 deletions rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ def train(self):
self.curr_frames = self.batch_size_envs

if self.multi_gpu:
+ torch.cuda.set_device(self.local_rank)
print("====================broadcasting parameters")
model_params = [self.model.state_dict()]
dist.broadcast_object_list(model_params, 0)
Expand Down

0 comments on commit 46c8c48

Please sign in to comment.