diff --git a/rl_games/common/common_losses.py b/rl_games/common/common_losses.py index c2ebaaa8..fbc887e6 100644 --- a/rl_games/common/common_losses.py +++ b/rl_games/common/common_losses.py @@ -3,9 +3,9 @@ import math -def grad_penalty_loss(obs_batch, actions_log_prob_batch): - grad_log_prob = torch.autograd.grad(actions_log_prob_batch.sum(), obs_batch, create_graph=True)[0] - gradient_penalty_loss = torch.sum(torch.square(grad_log_prob), dim=-1).mean() +def grad_penalty_loss(obs_batch, actions_log_probs): + gradient = torch.autograd.grad(actions_log_probs.sum(), obs_batch, create_graph=True)[0] + gradient_penalty_loss = torch.sum(gradient ** 2, dim=-1) return gradient_penalty_loss