Stage3: Use new torch grad accumulation hooks API #6773

Merged: 14 commits, merged Jan 3, 2025
7 changes: 2 additions & 5 deletions deepspeed/runtime/zero/stage3.py
@@ -1156,7 +1156,6 @@ def overlapping_partition_gradients_reduce_epilogue(self):
 
     def create_reduce_and_remove_grad_hooks(self):
         print_rank_0(f'[Begin] Create gradient reduction hooks')
-        self.grad_accs = []
         self.leaf_parameters = defaultdict(list)
         for i, param_group in enumerate(self.fp16_groups):
             for param in param_group:
@@ -1169,15 +1168,13 @@ def create_reduce_and_remove_grad_hooks(self):
 
                     #print(f"After all gather {param.device}, {param.shape}")
                     def wrapper(param):
-                        param_tmp = param.expand_as(param)
-                        grad_acc = param_tmp.grad_fn.next_functions[0][0]
 
                         @instrument_w_nvtx
                         def reduce_partition_and_remove_grads(*notneeded):
                             self.reduce_ready_partitions_and_remove_grads(param)
 
-                        self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
-                        self.grad_accs.append(grad_acc)
+                        self._grad_acc_hooks.append(
+                            param.register_post_accumulate_grad_hook(reduce_partition_and_remove_grads))

#print(f"param grad fn {param.expand_as(param).grad_fn}")
if z3_leaf_parameter(param):
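For reference, the sketch below (not DeepSpeed code) contrasts the two hook styles involved in this change: the old approach of digging the AccumulateGrad node out of the autograd graph and hooking it (which required keeping a reference to the node, the removed `self.grad_accs` list), and the `Tensor.register_post_accumulate_grad_hook` API available since PyTorch 2.1 that this PR switches to. The toy parameter, lambdas, and printed messages are illustrative assumptions.

```python
import torch

# Minimal sketch, assuming PyTorch >= 2.1 (where register_post_accumulate_grad_hook exists).
param = torch.nn.Parameter(torch.randn(4))

# Old style: reach the AccumulateGrad node via a dummy expand, hook that node, and keep a
# reference to it so the node is not garbage collected before backward runs.
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
old_handle = grad_acc.register_hook(lambda *unused: print("AccumulateGrad node hook fired"))

# New style: hook the parameter directly; the callback runs after the gradient has been
# accumulated into param.grad, and no extra bookkeeping is needed to keep anything alive.
new_handle = param.register_post_accumulate_grad_hook(
    lambda p: print(f"post-accumulate hook fired, grad norm {p.grad.norm().item():.3f}"))

# Both hooks fire once per backward pass for this leaf parameter.
loss = (param * 2.0).sum()
loss.backward()

old_handle.remove()
new_handle.remove()
```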