Refactor optimizer step logic (#163)
Co-authored-by: Costa Huang <[email protected]>
vwxyzjn and Costa Huang authored May 23, 2022
1 parent a320613, commit 86f5e82
Showing 3 changed files with 11 additions and 13 deletions.
rl_games/algos_torch/a2c_continuous.py (2 changes: 1 addition & 1 deletion)
@@ -135,7 +135,7 @@ def calc_gradients(self, input_dict):
 
         self.scaler.scale(loss).backward()
         #TODO: Refactor this ugliest code of they year
-        self.trancate_gradients()
+        self.trancate_gradients_and_step()
 
         with torch.no_grad():
             reduce_kl = rnn_masks is None
rl_games/algos_torch/a2c_discrete.py (2 changes: 1 addition & 1 deletion)
@@ -159,7 +159,7 @@ def calc_gradients(self, input_dict):
                 param.grad = None
 
         self.scaler.scale(loss).backward()
-        self.trancate_gradients()
+        self.trancate_gradients_and_step()
 
         with torch.no_grad():
             kl_dist = 0.5 * ((old_action_log_probs_batch - action_log_probs)**2)
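Both call sites follow PyTorch's usual AMP recipe: scale the loss, backpropagate, unscale before clipping, then step and update through the GradScaler. For readers unfamiliar with that interplay, here is a minimal, self-contained sketch of the pattern in generic PyTorch; the toy model, optimizer, and grad_norm value are illustrative and not taken from rl_games.

import torch
import torch.nn as nn

# Toy setup (illustrative only): a tiny model trained with AMP when a GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
grad_norm = 1.0
truncate_grads = True

def train_step(obs, target):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=(device == "cuda")):
        loss = nn.functional.mse_loss(model(obs), target)
    # Scale the loss so small fp16 gradients do not underflow, then backprop.
    scaler.scale(loss).backward()
    if truncate_grads:
        # Gradients must be unscaled before clipping, otherwise the norm
        # threshold would be compared against scaled values.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
    # step() skips the update if any gradient overflowed; update() adapts the scale.
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

loss = train_step(torch.randn(4, 8, device=device), torch.randn(4, 1, device=device))

The point of the commit is simply that the unscale/clip/step/update tail of this recipe now lives in one helper, trancate_gradients_and_step, instead of being split between the algorithm classes and a2c_common.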
rl_games/common/a2c_common.py (20 changes: 9 additions & 11 deletions)
@@ -254,18 +254,16 @@ def __init__(self, base_name, params):
         # soft augmentation not yet supported
         assert not self.has_soft_aug
 
-    def trancate_gradients(self):
+    def trancate_gradients_and_step(self):
+        if self.multi_gpu:
+            self.optimizer.synchronize()
 
         if self.truncate_grads:
-            if self.multi_gpu:
-                self.optimizer.synchronize()
-                self.scaler.unscale_(self.optimizer)
-                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
-                with self.optimizer.skip_synchronize():
-                    self.scaler.step(self.optimizer)
-                    self.scaler.update()
-            else:
-                self.scaler.unscale_(self.optimizer)
-                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
+            self.scaler.unscale_(self.optimizer)
+            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
 
+        if self.multi_gpu:
+            with self.optimizer.skip_synchronize():
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
         else:
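Assembled from the hunk above, the refactored helper likely reads as follows. This is a sketch rather than the verbatim file: the rendering is cut off at the final else, so the assumption is that the non-multi-GPU branch simply calls scaler.step and scaler.update, and the synchronize()/skip_synchronize() calls suggest a Horovod-style DistributedOptimizer wrapper.

# Sketch of trancate_gradients_and_step after this commit; the body of the
# final else branch is an assumption, since the hunk above is truncated there.
def trancate_gradients_and_step(self):
    if self.multi_gpu:
        # Aggregate gradients across workers before touching them
        # (synchronize()/skip_synchronize() are Horovod-style optimizer APIs).
        self.optimizer.synchronize()

    if self.truncate_grads:
        # Unscale AMP gradients first so clipping sees real gradient norms.
        self.scaler.unscale_(self.optimizer)
        nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

    if self.multi_gpu:
        with self.optimizer.skip_synchronize():
            self.scaler.step(self.optimizer)
            self.scaler.update()
    else:
        self.scaler.step(self.optimizer)  # assumed tail, not shown in the hunk
        self.scaler.update()

Compared with the old trancate_gradients, clipping no longer branches on multi_gpu: synchronization happens once up front, clipping depends only on truncate_grads, and only the final step/update pair depends on whether a distributed optimizer is in use.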
