Commit
Add gradient penalty loss based on LCP
romaf5 committed Oct 29, 2024
1 parent 90af59b commit 85e849d
Showing 4 changed files with 24 additions and 7 deletions.
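In short: the commit adds an optional gradient-penalty term to the actor/critic loss, penalizing the squared gradient of the action log-probabilities with respect to the observations. This is the regularizer used by Lipschitz-Constrained Policies, which is presumably what "LCP" in the title refers to. The term is weighted by a new gradient_penalty_coef config key that defaults to 0.0, so existing configs are unaffected. Per the diffs below, the continuous-action loss becomes roughly:

loss = a_loss + 0.5 * critic_coef * c_loss - entropy_coef * entropy + bounds_loss_coef * b_loss + gradient_penalty_coef * mean_over_batch( ||d log_prob / d obs||^2 )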
12 changes: 9 additions & 3 deletions rl_games/algos_torch/a2c_continuous.py
@@ -147,10 +147,16 @@ def calc_gradients(self, input_dict):
             b_loss = self.bound_loss(mu)
         else:
             b_loss = torch.zeros(1, device=self.ppo_device)
-        losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks)
-        a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]
 
-        loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef
+        if self.gradient_penalty_coef != 0.0:
+            gradient_penalty_loss = common_losses.grad_penalty_loss(obs_batch, action_log_probs)
+        else:
+            gradient_penalty_loss = torch.zeros(1, device=self.ppo_device)
+
+        losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1), gradient_penalty_loss.unsqueeze(1)], rnn_masks)
+        a_loss, c_loss, entropy, b_loss, gradient_penalty_loss = losses[0], losses[1], losses[2], losses[3], losses[4]
+
+        loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef + gradient_penalty_loss * self.gradient_penalty_coef
         aux_loss = self.model.get_aux_loss()
         self.aux_loss_dict = {}
         if aux_loss is not None:
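The new term is appended to the list passed through torch_ext.apply_masks, so it is reduced alongside the actor, critic, entropy, and bounds losses before entering the total loss. Note that grad_penalty_loss differentiates the log-probabilities with respect to obs_batch, so the observations have to require gradients and be part of the graph that produced action_log_probs (a requirement of torch.autograd.grad). A minimal, self-contained sketch of the same pattern on a toy Gaussian policy; all names below are illustrative, not the repository's calc_gradients:

import torch
import torch.nn as nn

def grad_penalty_loss(obs_batch, actions_log_prob_batch):
    # same body as the new rl_games.common.common_losses.grad_penalty_loss
    grad_log_prob = torch.autograd.grad(actions_log_prob_batch.sum(), obs_batch, create_graph=True)[0]
    return torch.sum(torch.square(grad_log_prob), dim=-1).mean()

obs_batch = torch.randn(64, 8, requires_grad=True)      # observations must require gradients
mu_net = nn.Linear(8, 2)                                 # toy actor mean head
sigma = torch.ones(2)

dist = torch.distributions.Normal(mu_net(obs_batch), sigma)
actions = dist.sample()
action_log_probs = dist.log_prob(actions).sum(dim=-1)    # per-sample log-probabilities

gradient_penalty_coef = 0.001                            # illustrative value
gp_loss = grad_penalty_loss(obs_batch, action_log_probs)
loss = gradient_penalty_coef * gp_loss                   # plus the usual PPO terms in practice
loss.backward()                                          # create_graph=True lets gradients reach mu_net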
12 changes: 8 additions & 4 deletions rl_games/algos_torch/a2c_discrete.py
@@ -166,10 +166,14 @@ def calc_gradients(self, input_dict):
         else:
             c_loss = torch.zeros(1, device=self.ppo_device)
 
-
-        losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks)
-        a_loss, c_loss, entropy = losses[0], losses[1], losses[2]
-        loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef
+        if self.gradient_penalty_coef != 0.0:
+            gradient_penalty_loss = common_losses.grad_penalty_loss(obs_batch, action_log_probs)
+        else:
+            gradient_penalty_loss = torch.zeros(1, device=self.ppo_device)
+
+        losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1), gradient_penalty_loss.unsqueeze(1)], rnn_masks)
+        a_loss, c_loss, entropy, gradient_penalty_loss = losses[0], losses[1], losses[2], losses[3]
+        loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef + gradient_penalty_loss * self.gradient_penalty_coef
         aux_loss = self.model.get_aux_loss()
         self.aux_loss_dict = {}
         if aux_loss is not None:
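The discrete path follows the same pattern without the bounds loss; the log-probabilities simply come from a categorical head instead of a Gaussian one. A rough, self-contained sketch with illustrative names only:

import torch
import torch.nn as nn

obs_batch = torch.randn(64, 8, requires_grad=True)       # observations in the autograd graph
logits_net = nn.Linear(8, 4)                              # toy policy head over 4 discrete actions

dist = torch.distributions.Categorical(logits=logits_net(obs_batch))
actions = dist.sample()
action_log_probs = dist.log_prob(actions)                 # shape (64,)

grad_log_prob = torch.autograd.grad(action_log_probs.sum(), obs_batch, create_graph=True)[0]
gradient_penalty_loss = torch.sum(torch.square(grad_log_prob), dim=-1).mean()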
1 change: 1 addition & 0 deletions rl_games/common/a2c_common.py
@@ -282,6 +282,7 @@ def __init__(self, base_name, params):
         os.makedirs(self.summaries_dir, exist_ok=True)
 
         self.entropy_coef = self.config['entropy_coef']
+        self.gradient_penalty_coef = self.config.get('gradient_penalty_coef', 0.0)
 
         if self.global_rank == 0:
             writer = SummaryWriter(self.summaries_dir)
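Because the coefficient is read with config.get(..., 0.0), the penalty is opt-in: configs that do not mention it fall into the torch.zeros(1) branch from the diffs above and train exactly as before. Enabling it is a matter of setting the key next to entropy_coef in the training config, for example (hypothetical value):

config = {'entropy_coef': 0.0}                            # an existing config without the new key
assert config.get('gradient_penalty_coef', 0.0) == 0.0    # penalty disabled by default

config['gradient_penalty_coef'] = 0.002                   # illustrative coefficient to turn the penalty on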
6 changes: 6 additions & 0 deletions rl_games/common/common_losses.py
@@ -3,6 +3,12 @@
 import math
 
+
+def grad_penalty_loss(obs_batch, actions_log_prob_batch):
+    grad_log_prob = torch.autograd.grad(actions_log_prob_batch.sum(), obs_batch, create_graph=True)[0]
+    gradient_penalty_loss = torch.sum(torch.square(grad_log_prob), dim=-1).mean()
+    return gradient_penalty_loss
+
 
 def critic_loss(model, value_preds_batch, values, curr_e_clip, return_batch, clip_value):
     return default_critic_loss(value_preds_batch, values, curr_e_clip, return_batch, clip_value)
     #return model.get_value_layer().loss(value_preds_batch=value_preds_batch, values=values, curr_e_clip=curr_e_clip, return_batch=return_batch, clip_value=clip_value)
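A quick way to see what the new helper computes: for a log-probability that is linear in the observations, the gradient with respect to each observation is just the weight vector, so the penalty equals its squared norm. An illustrative check (not part of the repository's tests), assuming the package is importable with this commit applied:

import torch
from rl_games.common import common_losses

obs = torch.randn(32, 5, requires_grad=True)
w = torch.tensor([1.0, -2.0, 0.5, 0.0, 3.0])
log_prob = obs @ w                                        # d(log_prob)/d(obs) == w for every sample

penalty = common_losses.grad_penalty_loss(obs, log_prob)
assert torch.allclose(penalty, (w ** 2).sum())            # 1 + 4 + 0.25 + 0 + 9 = 14.25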
