From 633bd4c795b7233b73a0a82781dfb6ff0ad3ebda Mon Sep 17 00:00:00 2001 From: Roman Date: Mon, 28 Oct 2024 22:19:22 -0700 Subject: [PATCH] make obs batch require grads as requirement for loss --- rl_games/algos_torch/a2c_continuous.py | 3 +++ rl_games/algos_torch/a2c_discrete.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 417b1163..d4dc59ca 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -112,6 +112,9 @@ def calc_gradients(self, input_dict): lr_mul = 1.0 curr_e_clip = self.e_clip + # set requires_grad to True for gradient penalty loss + obs_batch = obs_batch.requires_grad_(True) + batch_dict = { 'is_train': True, 'prev_actions': actions_batch, diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py index 684c107e..7a026dda 100644 --- a/rl_games/algos_torch/a2c_discrete.py +++ b/rl_games/algos_torch/a2c_discrete.py @@ -137,6 +137,9 @@ def calc_gradients(self, input_dict): lr_mul = 1.0 curr_e_clip = lr_mul * self.e_clip + # set requires_grad to True for gradient penalty loss + obs_batch = obs_batch.requires_grad_(True) + batch_dict = { 'is_train': True, 'prev_actions': actions_batch,