diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py
index 3dfcdadd..69df7913 100644
--- a/rl_games/algos_torch/players.py
+++ b/rl_games/algos_torch/players.py
@@ -199,7 +199,7 @@ def __init__(self, params):
         ]
 
         obs_shape = self.obs_shape
-        self.normalize_input = False
+        self.normalize_input = self.config.get('normalize_input', False)
         config = {
             'obs_dim': self.env_info["observation_space"].shape[0],
             'action_dim': self.env_info["action_space"].shape[0],
@@ -229,8 +229,9 @@ def restore(self, fn):
     def get_action(self, obs, is_deterministic=False):
         if self.has_batch_dimension == False:
             obs = unsqueeze_obs(obs)
+        obs = self.model.norm_obs(obs)
         dist = self.model.actor(obs)
-        actions = dist.sample() if is_deterministic else dist.mean
+        actions = dist.sample() if not is_deterministic else dist.mean
         actions = actions.clamp(*self.action_range).to(self.device)
         if self.has_batch_dimension == False:
             actions = torch.squeeze(actions.detach())
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index 4dc31a51..909ae3e5 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -211,6 +211,8 @@ def get_weights(self):
         state = {'actor': self.model.sac_network.actor.state_dict(),
         'critic': self.model.sac_network.critic.state_dict(),
         'critic_target': self.model.sac_network.critic_target.state_dict()}
+        if self.normalize_input:
+            state['running_mean_std'] = self.model.running_mean_std.state_dict()
         return state
 
     def save(self, fn):
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 7263e1f1..bda3d38f 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -1320,6 +1320,7 @@ def train(self):
         self.curr_frames = self.batch_size_envs
 
         if self.multi_gpu:
+            torch.cuda.set_device(self.local_rank)
             print("====================broadcasting parameters")
             model_params = [self.model.state_dict()]
             dist.broadcast_object_list(model_params, 0)
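
For context, below is a small self-contained sketch of the corrected action-selection behaviour in the SAC player after this patch: observations are normalized before the actor, and deterministic evaluation takes the distribution mean while stochastic evaluation samples. The TinyPlayer class and its norm_obs statistics are hypothetical stand-ins for rl_games' actual model and RunningMeanStd, not the library's API.

import torch
from torch.distributions import Normal

class TinyPlayer:
    """Hypothetical stand-in for rl_games' SACPlayer, for illustration only."""
    def __init__(self, obs_dim, act_dim, normalize_input=False):
        self.normalize_input = normalize_input
        # Stand-in for the running_mean_std statistics restored from a checkpoint.
        self.obs_mean = torch.zeros(obs_dim)
        self.obs_std = torch.ones(obs_dim)
        self.actor = torch.nn.Linear(obs_dim, act_dim)
        self.action_range = (-1.0, 1.0)

    def norm_obs(self, obs):
        # No-op unless input normalization was enabled at training time.
        if not self.normalize_input:
            return obs
        return (obs - self.obs_mean) / (self.obs_std + 1e-5)

    def get_action(self, obs, is_deterministic=False):
        obs = self.norm_obs(obs)
        mu = self.actor(obs)
        dist = Normal(mu, torch.ones_like(mu))
        # After the fix: deterministic evaluation uses the mean,
        # otherwise an action is sampled from the policy distribution.
        actions = dist.mean if is_deterministic else dist.sample()
        return actions.clamp(*self.action_range)

player = TinyPlayer(obs_dim=4, act_dim=2, normalize_input=True)
print(player.get_action(torch.randn(4), is_deterministic=True))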