4 changes: 4 additions & 0 deletions deepspeed/runtime/engine.py
@@ -343,6 +343,10 @@ def __init__(
         self.flatten = util_ops.flatten
         self.unflatten = util_ops.unflatten
 
+    def destroy(self):
+        if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
+            self.optimizer.destroy()
+
     def _get_model_parameters(self):
         if self.autotuning_profile_model_info():
             self.autotuning_model_info = {}
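A minimal usage sketch of the new engine-level destroy(): it assumes a toy model and a local ds_config.json (both hypothetical, not part of this PR) and shows where a caller would release the ZeRO stage 3 module hooks once training is finished.

```python
import torch
import deepspeed

# Hypothetical toy model and config path; any nn.Module / DeepSpeed config would do.
model = torch.nn.Linear(16, 16)
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config="ds_config.json")

# ... training loop ...

# New in this PR: forwards to the optimizer's destroy() (if it defines one),
# which detaches the forward/backward module hooks registered by ZeRO stage 3.
engine.destroy()
```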
35 changes: 31 additions & 4 deletions deepspeed/runtime/zero/stage3.py
@@ -359,7 +359,12 @@ def __init__(self,
 
         self.persistent_parameters = self.persistent_parameters()
 
+        self.forward_hooks = []
+        self.backward_hooks = []
         self.setup_zero_stage3_hooks()
+        print_rank_0(
+            f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
+            force=False)
 
         #resetting ds_tensor just in case parameters have been changed after initialization
         #example .half() or .to()
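The two new lists hold the torch.utils.hooks.RemovableHandle objects that PyTorch's hook-registration methods return. A standalone sketch of that pattern (the names here are illustrative, not DeepSpeed code):

```python
import torch.nn as nn
from torch.utils.hooks import RemovableHandle

module = nn.Linear(8, 8)
forward_hooks = []  # mirrors the new self.forward_hooks list

# Both registration calls return a RemovableHandle that can later detach the hook.
forward_hooks.append(module.register_forward_pre_hook(lambda m, inp: None))
forward_hooks.append(module.register_forward_hook(lambda m, inp, out: None))

assert all(isinstance(h, RemovableHandle) for h in forward_hooks)
```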
@@ -526,6 +531,23 @@ def __init__(self,
         if dist.get_rank(group=self.dp_process_group) == 0:
             see_memory_usage(f"After initializing ZeRO optimizer", force=True)
 
+    def destroy(self):
+        self._remove_module_hooks()
+
+    def _remove_module_hooks(self):
+        num_forward_hooks = len(self.forward_hooks)
+        num_backward_hooks = len(self.backward_hooks)
+
+        for hook in self.forward_hooks:
+            hook.remove()
+
+        for hook in self.backward_hooks:
+            hook.remove()
+
+        print_rank_0(
+            f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
+            force=False)
+
     def _setup_for_real_optimizer(self):
         see_memory_usage("Before creating fp32 partitions", force=False)
         self._create_fp32_partitions()
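For reference, a condensed sketch of the cleanup this hunk adds: each stored handle's remove() detaches its hook from the module, and RemovableHandle.remove() is a no-op when the hook is already gone, so a repeated destroy() should be harmless. (The PR itself only removes the handles; clearing the lists below is an extra step in this sketch.)

```python
def remove_module_hooks(forward_hooks, backward_hooks):
    """Detach every stored hook handle (sketch mirroring _remove_module_hooks)."""
    for handle in forward_hooks:
        handle.remove()
    for handle in backward_hooks:
        handle.remove()
    # Extra step not in the PR: drop the stale handles as well.
    forward_hooks.clear()
    backward_hooks.clear()
```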
@@ -1201,15 +1223,20 @@ def _run_after_backward_function(sub_module):
                 inputs)
 
         # Pre forward hook
-        module.register_forward_pre_hook(_pre_forward_module_hook)
+        self.forward_hooks.append(
+            module.register_forward_pre_hook(_pre_forward_module_hook))
 
         # Post forward hook
-        module.register_forward_hook(_post_forward_module_hook)
+        self.forward_hooks.append(
+            module.register_forward_hook(_post_forward_module_hook))
 
         # Pre backward hook
-        module.register_forward_hook(_pre_backward_module_hook)
+        self.backward_hooks.append(
+            module.register_forward_hook(_pre_backward_module_hook))
 
         # post backward hook
-        module.register_forward_pre_hook(_post_backward_module_hook)
+        self.backward_hooks.append(
+            module.register_forward_pre_hook(_post_backward_module_hook))
 
     @torch.no_grad()
     def pre_sub_module_forward_function(self, sub_module):
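Note that both "backward" hooks above are registered through forward-hook APIs: a pre-backward action has to be scheduled during the forward pass, on the autograd graph of the module's output (DeepSpeed does this by wrapping module outputs in custom autograd functions). A simplified, hypothetical sketch of the idea, not DeepSpeed's actual mechanism:

```python
import torch

def attach_pre_backward(module, callback):
    """Run `callback(module)` just before gradients flow back into `module`.

    Simplified sketch: a forward hook attaches a tensor grad hook to the
    output, and that grad hook fires right before backward reaches the module.
    """
    def _forward_hook(mod, inputs, output):
        if isinstance(output, torch.Tensor) and output.requires_grad:
            def _grad_hook(grad):
                callback(mod)   # run the pre-backward action
                return grad     # leave the gradient unchanged
            output.register_hook(_grad_hook)
        return output

    # The returned RemovableHandle is what would be stored in backward_hooks.
    return module.register_forward_hook(_forward_hook)
```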