diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py index 3697a7f3..dad8de0c 100644 --- a/rl_games/algos_torch/sac_agent.py +++ b/rl_games/algos_torch/sac_agent.py @@ -493,8 +493,9 @@ def play_steps(self, random_exploration = False): if isinstance(next_obs, dict): next_obs_processed = next_obs['obs'] - - self.obs = next_obs.clone() + self.obs = next_obs_processed.clone() + else: + self.obs = next_obs.clone() rewards = self.rewards_shaper(rewards)