Patch summary (free text ahead of the first "diff --git" header).

Hunk 1: get_all_parameters() gains an optional `recurse` parameter
(default False, i.e. the previously hard-coded value) that is forwarded
to sub_module.named_parameters().

Hunk 2: setup_zero_stage3_hooks() additionally defines _pre_forward_hook,
which calls self.param_coordinator.reset_step(), and registers it on
self.module via register_forward_pre_hook, alongside the existing
_end_of_forward_hook that only resets when grad is disabled.

NOTE(review): this patch was recovered from a copy whose line breaks had
been collapsed. Leading whitespace and the position of the blank context
line inside _end_of_forward_hook are reconstructed to satisfy the hunk
line counts (8/8 and 13/19) - confirm against the target file before
applying.

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index e5299949fcf6..493106e93239 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -98,8 +98,8 @@ def move_to_cpu(tensor_list):
         tensor.data = tensor.data.cpu()
 
 
-def get_all_parameters(sub_module):
-    return itertools.chain(sub_module.named_parameters(recurse=False),
+def get_all_parameters(sub_module, recurse=False):
+    return itertools.chain(sub_module.named_parameters(recurse=recurse),
                            sub_module.ds_external_parameters())
 
 
@@ -1037,13 +1037,19 @@ def setup_zero_stage3_hooks(self):
         self.hierarchy = 0
         self._register_hooks_recursively(self.module)
 
+        #reset step at the beginning of forward
+        def _pre_forward_hook(module, *args):
+            self.param_coordinator.reset_step()
+
         #reset step if in inference mode
         def _end_of_forward_hook(module, *args):
 
             if not torch._C.is_grad_enabled():
                 self.param_coordinator.reset_step()
 
+        #likely one of them should be enough but just to be safe
         self.module.register_forward_hook(_end_of_forward_hook)
+        self.module.register_forward_pre_hook(_pre_forward_hook)
 
     def persistent_parameters(self):
         persistent_params = []