diff --git a/rl_games/algos_torch/players.py b/rl_games/algos_torch/players.py index 3dfcdadd..2f82519b 100644 --- a/rl_games/algos_torch/players.py +++ b/rl_games/algos_torch/players.py @@ -230,7 +230,7 @@ def get_action(self, obs, is_deterministic=False): if self.has_batch_dimension == False: obs = unsqueeze_obs(obs) dist = self.model.actor(obs) - actions = dist.sample() if is_deterministic else dist.mean + actions = dist.sample() if not is_deterministic else dist.mean actions = actions.clamp(*self.action_range).to(self.device) if self.has_batch_dimension == False: actions = torch.squeeze(actions.detach())