diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 417b1163..d4dc59ca 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -112,6 +112,9 @@ def calc_gradients(self, input_dict): lr_mul = 1.0 curr_e_clip = self.e_clip + # set requires_grad to True for gradient penalty loss + obs_batch = obs_batch.requires_grad_(True) + batch_dict = { 'is_train': True, 'prev_actions': actions_batch, diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py index 684c107e..7a026dda 100644 --- a/rl_games/algos_torch/a2c_discrete.py +++ b/rl_games/algos_torch/a2c_discrete.py @@ -137,6 +137,9 @@ def calc_gradients(self, input_dict): lr_mul = 1.0 curr_e_clip = lr_mul * self.e_clip + # set requires_grad to True for gradient penalty loss + obs_batch = obs_batch.requires_grad_(True) + batch_dict = { 'is_train': True, 'prev_actions': actions_batch,