Adjust loss to work with apply_masks function (similar to all other losses)
romaf5 committed Oct 29, 2024
1 parent 85e849d commit 32c47df
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rl_games/common/common_losses.py
@@ -3,9 +3,9 @@
 import math


-def grad_penalty_loss(obs_batch, actions_log_prob_batch):
-    grad_log_prob = torch.autograd.grad(actions_log_prob_batch.sum(), obs_batch, create_graph=True)[0]
-    gradient_penalty_loss = torch.sum(torch.square(grad_log_prob), dim=-1).mean()
+def grad_penalty_loss(obs_batch, actions_log_probs):
+    gradient = torch.autograd.grad(actions_log_probs.sum(), obs_batch, create_graph=True)[0]
+    gradient_penalty_loss = torch.sum(gradient ** 2, dim=-1)
     return gradient_penalty_loss
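
The practical effect of this change is that the function now returns one penalty value per sample rather than a scalar mean, so the reduction can be done later with a mask. A minimal sketch of that, assuming a hypothetical `masked_mean` helper (a stand-in for illustration, not the actual `apply_masks` implementation, whose normalization may differ) and a toy log-probability chosen so the gradient is easy to check by hand:

```python
import torch


def grad_penalty_loss(obs_batch, actions_log_probs):
    # Same body as the new version in the diff: one penalty value per sample.
    gradient = torch.autograd.grad(actions_log_probs.sum(), obs_batch, create_graph=True)[0]
    return torch.sum(gradient ** 2, dim=-1)


def masked_mean(per_sample_loss, mask=None):
    # Hypothetical helper for illustration only; NOT the rl_games apply_masks code.
    # Assumption: average over the entries where mask == 1.
    if mask is None:
        return per_sample_loss.mean()
    mask = mask.to(per_sample_loss.dtype)
    return (per_sample_loss * mask).sum() / mask.sum().clamp(min=1.0)


# Toy "policy": log-prob = -0.5 * ||obs||^2, so d(log-prob)/d(obs) = -obs.
obs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]], requires_grad=True)
log_probs = -0.5 * (obs ** 2).sum(dim=-1)

per_sample = grad_penalty_loss(obs, log_probs)  # shape (3,), one value per sample
mask = torch.tensor([1.0, 0.0, 1.0])            # ignore the second sample

unmasked = masked_mean(per_sample)              # plain mean over all samples
masked = masked_mean(per_sample, mask)          # mean over unmasked samples only
```

With the old version, the `.mean()` inside the function collapsed the batch dimension before any mask could be applied, so masked-out samples would still have contributed to the penalty.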

