Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def move_to_cpu(tensor_list):
tensor.data = tensor.data.cpu()


def get_all_parameters(sub_module):
return itertools.chain(sub_module.named_parameters(recurse=False),
def get_all_parameters(sub_module, recurse=False):
    """Iterate over a module's named parameters, then its external ones.

    Args:
        sub_module: Object exposing ``named_parameters(recurse=...)`` and
            ``ds_external_parameters()`` (presumably a ``torch.nn.Module``
            extended by DeepSpeed -- confirm against callers).
        recurse: Forwarded unchanged to ``named_parameters``. The default
            of ``False`` keeps the result identical for callers that pass
            only ``sub_module``.

    Returns:
        A lazy ``itertools.chain`` that yields everything produced by
        ``named_parameters`` first and everything produced by
        ``ds_external_parameters`` second.
    """
    return itertools.chain(sub_module.named_parameters(recurse=recurse),
                           sub_module.ds_external_parameters())


Expand Down Expand Up @@ -1037,13 +1037,19 @@ def setup_zero_stage3_hooks(self):
self.hierarchy = 0
self._register_hooks_recursively(self.module)

#reset step at the beginning of forward
def _pre_forward_hook(module, *args):
    """Forward pre-hook that resets the parameter coordinator's step.

    ``module`` and ``*args`` are accepted to satisfy the hook calling
    convention and are not used. ``self`` is captured from the enclosing
    method's scope.
    """
    self.param_coordinator.reset_step()

#reset step if in inference mode
def _end_of_forward_hook(module, *args):

if not torch._C.is_grad_enabled():
self.param_coordinator.reset_step()

# One of these two hooks is likely sufficient; both are registered to be safe.
self.module.register_forward_hook(_end_of_forward_hook)
self.module.register_forward_pre_hook(_pre_forward_hook)

def persistent_parameters(self):
persistent_params = []
Expand Down