From 3907e0e610a7a364b7952e2cdb54e0545d7235c9 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Mon, 23 May 2022 08:56:30 -0700
Subject: [PATCH] Refactor optimizer step logic

---
 rl_games/algos_torch/a2c_continuous.py |  2 +-
 rl_games/algos_torch/a2c_discrete.py   |  2 +-
 rl_games/common/a2c_common.py          | 20 +++++++++-----------
 3 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 7cfd07d3..489e8e0e 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -135,7 +135,7 @@ def calc_gradients(self, input_dict):
 
         self.scaler.scale(loss).backward()
         #TODO: Refactor this ugliest code of they year
-        self.trancate_gradients()
+        self.trancate_gradients_and_step()
 
         with torch.no_grad():
             reduce_kl = rnn_masks is None
diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index 26234236..6f912788 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -159,7 +159,7 @@ def calc_gradients(self, input_dict):
                 param.grad = None
 
         self.scaler.scale(loss).backward()
-        self.trancate_gradients()
+        self.trancate_gradients_and_step()
 
         with torch.no_grad():
             kl_dist = 0.5 * ((old_action_log_probs_batch - action_log_probs)**2)
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 9708aadc..428902e2 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -252,18 +252,16 @@ def __init__(self, base_name, params):
         # soft augmentation not yet supported
         assert not self.has_soft_aug
 
-    def trancate_gradients(self):
+    def trancate_gradients_and_step(self):
+        if self.multi_gpu:
+            self.optimizer.synchronize()
+
         if self.truncate_grads:
-            if self.multi_gpu:
-                self.optimizer.synchronize()
-                self.scaler.unscale_(self.optimizer)
-                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
-                with self.optimizer.skip_synchronize():
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
-            else:
-                self.scaler.unscale_(self.optimizer)
-                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
+            self.scaler.unscale_(self.optimizer)
+            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
+
+        if self.multi_gpu:
+            with self.optimizer.skip_synchronize():
                 self.scaler.step(self.optimizer)
                 self.scaler.update()
         else:
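
For reference, below is a sketch of how the consolidated method reads in rl_games/common/a2c_common.py once the hunks above are applied. The body of the trailing else: branch sits outside the diff context, so the plain scaler.step/scaler.update calls shown there are an assumption carried over from the single-GPU branch that this patch removes; nn refers to torch.nn, which the module already imports.

    def trancate_gradients_and_step(self):
        # Multi-GPU: let the distributed optimizer average gradients across
        # workers (Horovod-style synchronize) before any unscaling or clipping.
        if self.multi_gpu:
            self.optimizer.synchronize()

        # Gradient clipping now happens once, on the same path for both the
        # single- and multi-GPU cases.
        if self.truncate_grads:
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

        if self.multi_gpu:
            # Gradients were already synchronized above, so skip the implicit
            # synchronization inside optimizer.step().
            with self.optimizer.skip_synchronize():
                self.scaler.step(self.optimizer)
                self.scaler.update()
        else:
            # Assumed continuation (outside the hunk): plain AMP step/update.
            self.scaler.step(self.optimizer)
            self.scaler.update()

With this helper in place, calc_gradients() in both a2c_continuous.py and a2c_discrete.py only needs to call self.trancate_gradients_and_step() right after self.scaler.scale(loss).backward(), as the first two hunks show, instead of duplicating the clip/step/update branching at each call site.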