diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py
index 2f82519b..69df7913 100644
--- a/rl_games/algos_torch/players.py
+++ b/rl_games/algos_torch/players.py
@@ -199,7 +199,7 @@ def __init__(self, params):
         ]
 
         obs_shape = self.obs_shape
-        self.normalize_input = False
+        self.normalize_input = self.config.get('normalize_input', False)
         config = {
             'obs_dim': self.env_info["observation_space"].shape[0],
             'action_dim': self.env_info["action_space"].shape[0],
@@ -229,6 +229,7 @@ def restore(self, fn):
     def get_action(self, obs, is_deterministic=False):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
+        obs = self.model.norm_obs(obs)
         dist = self.model.actor(obs)
         actions = dist.sample() if not is_deterministic else dist.mean
         actions = actions.clamp(*self.action_range).to(self.device)
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index dad8de0c..fd79fb7a 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -208,6 +208,8 @@ def get_weights(self):
         state = {'actor': self.model.sac_network.actor.state_dict(),
                  'critic': self.model.sac_network.critic.state_dict(),
                  'critic_target': self.model.sac_network.critic_target.state_dict()}
+        if self.normalize_input:
+            state['running_mean_std'] = self.model.running_mean_std.state_dict()
         return state
 
     def save(self, fn):
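
The diff above only shows the saving side of the checkpoint. For a restored checkpoint to reproduce evaluation behaviour, the loading path also has to put the saved running statistics back into the model, so that `norm_obs` in `get_action` normalizes with the same mean/std that was accumulated during training. A minimal sketch of that counterpart, assuming a `set_weights` method that receives the dict produced by `get_weights` above (the method name and loading calls are assumptions, not part of this diff):

```python
def set_weights(self, weights):
    # Mirror of get_weights above: restore the three SAC networks...
    self.model.sac_network.actor.load_state_dict(weights['actor'])
    self.model.sac_network.critic.load_state_dict(weights['critic'])
    self.model.sac_network.critic_target.load_state_dict(weights['critic_target'])
    # ...and, when observation normalization is enabled, the running mean/std,
    # so that norm_obs() at inference time uses the training-time statistics.
    if self.normalize_input and 'running_mean_std' in weights:
        self.model.running_mean_std.load_state_dict(weights['running_mean_std'])
```

With the player now reading `normalize_input` from its config instead of hard-coding `False`, enabling the same flag in the SAC config should make both the agent and the player apply observation normalization consistently.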