From 32c47df8eb2a126c9c3e6e1cd7c2547e0b9ee619 Mon Sep 17 00:00:00 2001
From: Roman
Date: Mon, 28 Oct 2024 21:37:45 -0700
Subject: [PATCH] Adjust loss to work with apply_masks function (similar to all other losses)

---
 rl_games/common/common_losses.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/rl_games/common/common_losses.py b/rl_games/common/common_losses.py
index c2ebaaa8..fbc887e6 100644
--- a/rl_games/common/common_losses.py
+++ b/rl_games/common/common_losses.py
@@ -3,9 +3,9 @@ import math
 
 
-def grad_penalty_loss(obs_batch, actions_log_prob_batch):
-    grad_log_prob = torch.autograd.grad(actions_log_prob_batch.sum(), obs_batch, create_graph=True)[0]
-    gradient_penalty_loss = torch.sum(torch.square(grad_log_prob), dim=-1).mean()
+def grad_penalty_loss(obs_batch, actions_log_probs):
+    gradient = torch.autograd.grad(actions_log_probs.sum(), obs_batch, create_graph=True)[0]
+    gradient_penalty_loss = torch.sum(gradient ** 2, dim=-1)
     return gradient_penalty_loss
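
Note (commentary, not part of the patch): with the .mean() reduction removed, grad_penalty_loss now returns a per-sample penalty, so the caller can mask out padded timesteps before averaging, matching how the other losses are reduced. A minimal sketch of the intended usage follows; masked_mean is a hypothetical stand-in for the mask-aware reduction the caller would apply, not the actual rl_games apply_masks helper.

import torch

def grad_penalty_loss(obs_batch, actions_log_probs):
    # Gradient of the summed log-probs w.r.t. the observations; create_graph=True
    # keeps the graph so the penalty itself can be backpropagated.
    gradient = torch.autograd.grad(actions_log_probs.sum(), obs_batch, create_graph=True)[0]
    # Per-sample squared gradient norm; reduction is left to the caller.
    return torch.sum(gradient ** 2, dim=-1)

def masked_mean(per_sample_loss, mask=None):
    # Hypothetical reduction: average only over unmasked samples.
    if mask is None:
        return per_sample_loss.mean()
    return (per_sample_loss * mask).sum() / mask.sum().clamp(min=1.0)

# Toy example: observations and a stand-in for policy log-probs.
obs = torch.randn(8, 4, requires_grad=True)
log_probs = (-0.5 * obs ** 2).sum(dim=-1)
mask = torch.tensor([1., 1., 1., 0., 1., 1., 0., 1.])
penalty = masked_mean(grad_penalty_loss(obs, log_probs), mask)
penalty.backward()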