diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index e93ea362..417b1163 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -147,10 +147,16 @@ def calc_gradients(self, input_dict):
                 b_loss = self.bound_loss(mu)
             else:
                 b_loss = torch.zeros(1, device=self.ppo_device)
-            losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1)], rnn_masks)
-            a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]
-            loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef
+            if self.gradient_penalty_coef != 0.0:
+                gradient_penalty_loss = common_losses.grad_penalty_loss(obs_batch, action_log_probs)
+            else:
+                gradient_penalty_loss = torch.zeros(1, device=self.ppo_device)
+
+            losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss , entropy.unsqueeze(1), b_loss.unsqueeze(1), gradient_penalty_loss.unsqueeze(1)], rnn_masks)
+            a_loss, c_loss, entropy, b_loss, gradient_penalty_loss = losses[0], losses[1], losses[2], losses[3], losses[4]
+
+            loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef + b_loss * self.bounds_loss_coef + gradient_penalty_loss * self.gradient_penalty_coef
 
             aux_loss = self.model.get_aux_loss()
             self.aux_loss_dict = {}
             if aux_loss is not None:
diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index a69ba9bb..684c107e 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -166,10 +166,14 @@ def calc_gradients(self, input_dict):
             else:
                 c_loss = torch.zeros(1, device=self.ppo_device)
-
-            losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks)
-            a_loss, c_loss, entropy = losses[0], losses[1], losses[2]
-            loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef
+            if self.gradient_penalty_coef != 0.0:
+                gradient_penalty_loss = common_losses.grad_penalty_loss(obs_batch, action_log_probs)
+            else:
+                gradient_penalty_loss = torch.zeros(1, device=self.ppo_device)
+
+            losses, sum_mask = torch_ext.apply_masks([a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1), gradient_penalty_loss.unsqueeze(1)], rnn_masks)
+            a_loss, c_loss, entropy, gradient_penalty_loss = losses[0], losses[1], losses[2], losses[3]
+            loss = a_loss + 0.5 *c_loss * self.critic_coef - entropy * self.entropy_coef + gradient_penalty_loss * self.gradient_penalty_coef
 
             aux_loss = self.model.get_aux_loss()
             self.aux_loss_dict = {}
             if aux_loss is not None:
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 54a5cda1..f58c9354 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -282,6 +282,7 @@ def __init__(self, base_name, params):
         os.makedirs(self.summaries_dir, exist_ok=True)
 
         self.entropy_coef = self.config['entropy_coef']
+        self.gradient_penalty_coef = self.config.get('gradient_penalty_coef', 0.0)
 
         if self.global_rank == 0:
             writer = SummaryWriter(self.summaries_dir)
diff --git a/rl_games/common/common_losses.py b/rl_games/common/common_losses.py
index 04c00644..c2ebaaa8 100644
--- a/rl_games/common/common_losses.py
+++ b/rl_games/common/common_losses.py
@@ -3,6 +3,12 @@
 import math
 
 
+def grad_penalty_loss(obs_batch, actions_log_prob_batch):
+    grad_log_prob = torch.autograd.grad(actions_log_prob_batch.sum(), obs_batch, create_graph=True)[0]
+    gradient_penalty_loss = torch.sum(torch.square(grad_log_prob), dim=-1).mean()
+    return gradient_penalty_loss
+
+
 def critic_loss(model, value_preds_batch, values, curr_e_clip, return_batch, clip_value):
     return default_critic_loss(value_preds_batch, values, curr_e_clip, return_batch, clip_value)
     #return model.get_value_layer().loss(value_preds_batch=value_preds_batch, values=values, curr_e_clip=curr_e_clip, return_batch=return_batch, clip_value=clip_value)
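For reviewers, here is a minimal standalone sketch of what the new grad_penalty_loss computes. It is not part of the patch; it assumes the patch is applied (so rl_games.common.common_losses exposes grad_penalty_loss) and uses a hypothetical toy policy and batch shapes. Note that torch.autograd.grad can only differentiate the log-probabilities with respect to obs_batch if that tensor has requires_grad=True, so the caller of calc_gradients is expected to enable gradients on the observations when gradient_penalty_coef is non-zero.

import torch

from rl_games.common import common_losses

# Toy setup (hypothetical shapes): 8 observations with 4 features each, and a
# linear "policy" producing one action log-probability per sample.
obs_batch = torch.randn(8, 4, requires_grad=True)  # gradient w.r.t. obs needs requires_grad=True
policy = torch.nn.Linear(4, 1)
action_log_probs = policy(obs_batch).squeeze(-1)   # shape (8,)

# grad_penalty_loss returns the mean squared L2 norm of d(log pi)/d(obs):
# sum of squares over the feature dimension, then mean over the batch.
gp_loss = common_losses.grad_penalty_loss(obs_batch, action_log_probs)

# Because the inner gradient is taken with create_graph=True, the penalty is
# itself differentiable, so it can be scaled by gradient_penalty_coef and
# added to the total PPO loss before backward().
(0.1 * gp_loss).backward()

Since a2c_common.py reads the coefficient with self.config.get('gradient_penalty_coef', 0.0), existing configs are unaffected; the penalty term is only active when the key is set to a non-zero value.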