diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 12b24a0e..7dafb2d5 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -95,7 +95,7 @@ def restore_central_value_function(self, fn): def get_masked_action_values(self, obs, action_masks): assert False -# @torch.compile() #(mode='max-autotune') + @torch.compile() #(mode='max-autotune') def calc_losses(self, actor_loss_func, old_action_log_probs_batch, action_log_probs, advantage, curr_e_clip, value_preds_batch, values, return_batch, mu, entropy, rnn_masks): a_loss = actor_loss_func(old_action_log_probs_batch, action_log_probs, advantage, self.ppo, curr_e_clip)