Commit

Forgot to return func
Signed-off-by: Tim Moon <[email protected]>
timmoon10 committed Jul 10, 2023
1 parent 0635aac commit 9611f8b
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions nemo/core/optim/distributed_adam.py
@@ -121,14 +121,20 @@ def __init__(self, params, disable_distributed_parameters=False, **kwargs):
 
     def _fp32_register_post_backward_hooks(self):
         """Attach hooks for FP32 gradients"""
 
         # Helper function to avoid issues with late binding closures
         def make_post_backward_hook(param):
             def post_backward_hook(*unused):
                 self._fp32_optim_grad_sync_needed = True
-                with torch.no_grad():
-                    if hasattr(param, 'main_grad'):
+                if hasattr(param, 'main_grad'):
+                    with torch.no_grad():
                         if param.grad is not None:
-                            param.main_grad += param.grad.detach()
+                            param.main_grad += param.grad
                         param.grad = None
+
+            return post_backward_hook
 
         # Construct hooks and register with params
         self._fp32_grad_accs = []
         for param in self._fp32_optim_main_params.keys():
             param_tmp = param.expand_as(param)
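
The helper comment and the commit title describe two related Python pitfalls. Closures bind names late, so a hook defined directly inside the registration loop would read whatever param refers to when backward eventually runs (the last parameter of the loop); the make_post_backward_hook factory exists to pin each hook to its own param. The bug fixed here is that the factory built post_backward_hook but never returned it, so every call yielded None. A minimal sketch of both behaviors, with hypothetical names rather than NeMo's actual code:

# Sketch: late-binding closures and the missing-return bug (hypothetical names).

def make_hooks_buggy(params):
    hooks = []
    for p in params:
        def hook():
            return p  # 'p' is looked up when the hook runs, not when it is defined
        hooks.append(hook)
    return hooks

def make_hooks_fixed(params):
    def make_hook(p):  # each factory call captures its own 'p'
        def hook():
            return p
        return hook  # omit this line and make_hook() silently returns None
    return [make_hook(p) for p in params]

print([h() for h in make_hooks_buggy([1, 2, 3])])  # [3, 3, 3]
print([h() for h in make_hooks_fixed([1, 2, 3])])  # [1, 2, 3]

Before the fix, the value later registered as a hook was therefore None rather than post_backward_hook.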
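The truncated loop at the bottom of the hunk hints at how these hooks attach: param.expand_as(param) produces a view whose grad_fn leads back to the parameter's AccumulateGrad node, and a hook registered on that node fires once the parameter's gradient has been accumulated. A runnable sketch of that pattern; the registration call and the hook body here are inferred from the visible context, not taken from the full NeMo source:

# Sketch of the grad-accumulator hook pattern suggested by the loop in the
# diff; names outside the visible context are assumptions.
import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]

def make_post_backward_hook(param):
    def post_backward_hook(*unused):
        # In the commit, this is where param.grad would be folded into
        # param.main_grad and then cleared.
        print("grad ready for param with shape", tuple(param.shape))
    return post_backward_hook

grad_accs = []  # keep references so the autograd nodes are not collected
for param in params:
    param_tmp = param.expand_as(param)  # view of a leaf tensor, so it has a grad_fn
    grad_acc = param_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node
    grad_acc.register_hook(make_post_backward_hook(param))
    grad_accs.append(grad_acc)

loss = sum((p * p).sum() for p in params)
loss.backward()  # each hook fires as its parameter's gradient becomes available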
