Skip to content

Commit

Permalink
Fix SAC with input normalization (#291)
Browse files Browse the repository at this point in the history
* If self.normalize_input is True in SACAgent class, add the weights of the running_mean_std layer in get_weights method

* Allow getting normalize_input from config and use self.model.norm_obs in get_action method

---------

Co-authored-by: Lukas Linauer <[email protected]>
  • Loading branch information
llinauer and lukaslinauer authored Jun 25, 2024
1 parent dec7275 commit 07043a3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 2 additions & 1 deletion rl_games/algos_torch/players.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def __init__(self, params):
]

obs_shape = self.obs_shape
self.normalize_input = False
self.normalize_input = self.config.get('normalize_input', False)
config = {
'obs_dim': self.env_info["observation_space"].shape[0],
'action_dim': self.env_info["action_space"].shape[0],
Expand Down Expand Up @@ -229,6 +229,7 @@ def restore(self, fn):
def get_action(self, obs, is_deterministic=False):
if self.has_batch_dimension == False:
obs = unsqueeze_obs(obs)
obs = self.model.norm_obs(obs)
dist = self.model.actor(obs)
actions = dist.sample() if not is_deterministic else dist.mean
actions = actions.clamp(*self.action_range).to(self.device)
Expand Down
2 changes: 2 additions & 0 deletions rl_games/algos_torch/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def get_weights(self):
state = {'actor': self.model.sac_network.actor.state_dict(),
'critic': self.model.sac_network.critic.state_dict(),
'critic_target': self.model.sac_network.critic_target.state_dict()}
if self.normalize_input:
state['running_mean_std'] = self.model.running_mean_std.state_dict()
return state

def save(self, fn):
Expand Down

0 comments on commit 07043a3

Please sign in to comment.