Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refacotor optimizer step logic #163

Merged
merged 1 commit into from
May 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def calc_gradients(self, input_dict):

self.scaler.scale(loss).backward()
#TODO: Refactor this ugliest code of they year
self.trancate_gradients()
self.trancate_gradients_and_step()

with torch.no_grad():
reduce_kl = rnn_masks is None
Expand Down
2 changes: 1 addition & 1 deletion rl_games/algos_torch/a2c_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def calc_gradients(self, input_dict):
param.grad = None

self.scaler.scale(loss).backward()
self.trancate_gradients()
self.trancate_gradients_and_step()

with torch.no_grad():
kl_dist = 0.5 * ((old_action_log_probs_batch - action_log_probs)**2)
Expand Down
20 changes: 9 additions & 11 deletions rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,18 +252,16 @@ def __init__(self, base_name, params):
# soft augmentation not yet supported
assert not self.has_soft_aug

def trancate_gradients(self):
def trancate_gradients_and_step(self):
if self.multi_gpu:
self.optimizer.synchronize()

if self.truncate_grads:
if self.multi_gpu:
self.optimizer.synchronize()
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
with self.optimizer.skip_synchronize():
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

if self.multi_gpu:
with self.optimizer.skip_synchronize():
self.scaler.step(self.optimizer)
self.scaler.update()
else:
Expand Down