diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 00ce892f62ca..3112fb737aff 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -343,6 +343,10 @@ def __init__(
         self.flatten = util_ops.flatten
         self.unflatten = util_ops.unflatten
 
+    def destroy(self):
+        if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
+            self.optimizer.destroy()
+
     def _get_model_parameters(self):
         if self.autotuning_profile_model_info():
             self.autotuning_model_info = {}
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 3f4768163a2c..e963ef643677 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -359,7 +359,12 @@ def __init__(self,
 
         self.persistent_parameters = self.persistent_parameters()
 
+        self.forward_hooks = []
+        self.backward_hooks = []
         self.setup_zero_stage3_hooks()
+        print_rank_0(
+            f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
+            force=False)
 
         #resetting ds_tensor just in case parameters have been changed after initialization
         #example .half() or .to()
@@ -526,6 +531,23 @@ def __init__(self,
         if dist.get_rank(group=self.dp_process_group) == 0:
             see_memory_usage(f"After initializing ZeRO optimizer", force=True)
 
+    def destroy(self):
+        self._remove_module_hooks()
+
+    def _remove_module_hooks(self):
+        num_forward_hooks = len(self.forward_hooks)
+        num_backward_hooks = len(self.backward_hooks)
+
+        for hook in self.forward_hooks:
+            hook.remove()
+
+        for hook in self.backward_hooks:
+            hook.remove()
+
+        print_rank_0(
+            f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
+            force=False)
+
     def _setup_for_real_optimizer(self):
         see_memory_usage("Before creating fp32 partitions", force=False)
         self._create_fp32_partitions()
@@ -1201,15 +1223,20 @@ def _run_after_backward_function(sub_module):
                                          inputs)
 
         # Pre forward hook
-        module.register_forward_pre_hook(_pre_forward_module_hook)
+        self.forward_hooks.append(
+            module.register_forward_pre_hook(_pre_forward_module_hook))
+
         # Post forward hook
-        module.register_forward_hook(_post_forward_module_hook)
+        self.forward_hooks.append(
+            module.register_forward_hook(_post_forward_module_hook))
 
         # Pre backward hook
-        module.register_forward_hook(_pre_backward_module_hook)
+        self.backward_hooks.append(
+            module.register_forward_hook(_pre_backward_module_hook))
 
         # post backward hook
-        module.register_forward_pre_hook(_post_backward_module_hook)
+        self.backward_hooks.append(
+            module.register_forward_pre_hook(_post_backward_module_hook))
 
     @torch.no_grad()
     def pre_sub_module_forward_function(self, sub_module):
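
For context, a minimal usage sketch of the teardown path this diff adds: the engine-level destroy() forwards to the ZeRO stage-3 optimizer's destroy(), which unregisters the module hooks tracked above. The model, model_parameters, and ds_config names here are illustrative assumptions, not part of this diff.

    import deepspeed

    # assumes a torch.nn.Module `model` and a DeepSpeed config dict `ds_config` already exist
    engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                   model_parameters=model.parameters(),
                                                   config=ds_config)
    # ... training loop ...
    engine.destroy()  # removes the ZeRO stage-3 forward/backward hooks so they don't leak across engine instances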