From 3dc1d5c783648a096792edd9b5974a8979460cf9 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 26 Mar 2021 21:29:15 -0700
Subject: [PATCH 1/5] fix cpu_adam mem leak

---
 csrc/adam/cpu_adam.cpp         | 10 ++++++++++
 deepspeed/ops/adam/cpu_adam.py |  4 ++++
 2 files changed, 14 insertions(+)

diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp
index d425dc3169ef..bb6bd13c0c77 100644
--- a/csrc/adam/cpu_adam.cpp
+++ b/csrc/adam/cpu_adam.cpp
@@ -672,6 +672,15 @@ int ds_adam_step_plus_copy(int optimizer_id,
     return 0;
 }
 
+int destroy_adam_optimizer(int optimizer_id)
+{
+    // std::cout << "Adam Optimizer #" << optimizer_id
+    //           << " is destroyed." << std::endl;
+    s_optimizers.erase(optimizer_id);
+
+    return 0;
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
@@ -679,4 +688,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
           &ds_adam_step_plus_copy,
           "DeepSpeed CPU Adam update and param copy (C++)");
     m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
+    m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
 }
diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py
index 7977d232b1fa..418e0ccfcd81 100755
--- a/deepspeed/ops/adam/cpu_adam.py
+++ b/deepspeed/ops/adam/cpu_adam.py
@@ -85,6 +85,10 @@ def __init__(self,
                                      weight_decay,
                                      adamw_mode)
 
+    def __del__(self):
+        # need to destroy the C++ object explicitly avoid memory leak when deepspeed is re-used in the same process
+        self.ds_opt_adam.destroy_adam(self.opt_id)
+
     def __setstate__(self, state):
         super(DeepSpeedCPUAdam, self).__setstate__(state)
         for group in self.param_groups:

From 1990715c7705947bf4c1511e6be97477578b2d8c Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 26 Mar 2021 22:09:09 -0700
Subject: [PATCH 2/5] prose

---
 deepspeed/ops/adam/cpu_adam.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py
index 418e0ccfcd81..35eeedb86b5d 100755
--- a/deepspeed/ops/adam/cpu_adam.py
+++ b/deepspeed/ops/adam/cpu_adam.py
@@ -86,7 +86,8 @@ def __init__(self,
                                      adamw_mode)
 
     def __del__(self):
-        # need to destroy the C++ object explicitly avoid memory leak when deepspeed is re-used in the same process
+        # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
+        # is used multiple times in the same process (notebook or pytest worker)
         self.ds_opt_adam.destroy_adam(self.opt_id)
 
     def __setstate__(self, state):

From 5002e499f63ba51ab48d744bebdef18e54de9746 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Sat, 27 Mar 2021 17:04:11 +0000
Subject: [PATCH 3/5] receive the optimizer id from the client side when
 running step (default=0)

---
 deepspeed/ops/adam/cpu_adam.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py
index 35eeedb86b5d..dd12e6238aab 100755
--- a/deepspeed/ops/adam/cpu_adam.py
+++ b/deepspeed/ops/adam/cpu_adam.py
@@ -96,7 +96,7 @@ def __setstate__(self, state):
             group.setdefault('amsgrad', False)
 
     @torch.no_grad()
-    def step(self, closure=None, fp16_param_groups=None):
+    def step(self, optId=0, closure=None, fp16_param_groups=None):
         """Update the model parameters.
 
         .. note::
@@ -147,7 +147,7 @@ def step(self, closure=None, fp16_param_groups=None):
 
                 if fp16_param_groups is not None:
                     self.ds_opt_adam.adam_update_copy(
-                        self.opt_id,
+                        optId,
                         state['step'],
                         group['lr'],
                         beta1,
@@ -161,7 +161,7 @@ def step(self, closure=None, fp16_param_groups=None):
                         state['exp_avg_sq'],
                         fp16_param_groups[group_id][param_id].data)
                 else:
-                    self.ds_opt_adam.adam_update(self.opt_id,
+                    self.ds_opt_adam.adam_update(optId,
                                                  state['step'],
                                                  group['lr'],
                                                  beta1,

From 8a8244c2bd3225a37aabd068d188a0f4e90d702d Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Sat, 27 Mar 2021 11:42:22 -0700
Subject: [PATCH 4/5] remove comment

---
 csrc/adam/cpu_adam.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp
index bb6bd13c0c77..6726b895f12c 100644
--- a/csrc/adam/cpu_adam.cpp
+++ b/csrc/adam/cpu_adam.cpp
@@ -674,8 +674,6 @@ int ds_adam_step_plus_copy(int optimizer_id,
 
 int destroy_adam_optimizer(int optimizer_id)
 {
-    // std::cout << "Adam Optimizer #" << optimizer_id
-    //           << " is destroyed." << std::endl;
     s_optimizers.erase(optimizer_id);
 
     return 0;

From d7b4f4ed3628f5416f71e8f92e9beb1c6397a038 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Sun, 28 Mar 2021 21:36:29 +0000
Subject: [PATCH 5/5] revert changes for optimizer id

---
 deepspeed/ops/adam/cpu_adam.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py
index dd12e6238aab..35eeedb86b5d 100755
--- a/deepspeed/ops/adam/cpu_adam.py
+++ b/deepspeed/ops/adam/cpu_adam.py
@@ -96,7 +96,7 @@ def __setstate__(self, state):
             group.setdefault('amsgrad', False)
 
     @torch.no_grad()
-    def step(self, optId=0, closure=None, fp16_param_groups=None):
+    def step(self, closure=None, fp16_param_groups=None):
         """Update the model parameters.
 
         .. note::
@@ -147,7 +147,7 @@ def step(self, optId=0, closure=None, fp16_param_groups=None):
 
                 if fp16_param_groups is not None:
                     self.ds_opt_adam.adam_update_copy(
-                        optId,
+                        self.opt_id,
                         state['step'],
                         group['lr'],
                         beta1,
@@ -161,7 +161,7 @@ def step(self, optId=0, closure=None, fp16_param_groups=None):
                         state['exp_avg_sq'],
                         fp16_param_groups[group_id][param_id].data)
                 else:
-                    self.ds_opt_adam.adam_update(optId,
+                    self.ds_opt_adam.adam_update(self.opt_id,
                                                  state['step'],
                                                  group['lr'],
                                                  beta1,
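
Reviewer note (not part of the patches): the sketch below shows the re-use pattern this series fixes. It is a minimal illustration, assuming a DeepSpeed build where the cpu_adam op compiles on first use and DeepSpeedCPUAdam is importable from deepspeed.ops.adam (as in this tree); the model and loop sizes are arbitrary.

# Pattern that leaked before patch 1: a notebook or pytest worker
# constructing a fresh optimizer on every run within one process.
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

for trial in range(5):
    model = torch.nn.Linear(64, 64)
    optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)

    # One dummy training step so step() has gradients to consume.
    model(torch.randn(8, 64)).sum().backward()
    optimizer.step()

    # Dropping the last Python reference fires DeepSpeedCPUAdam.__del__,
    # which now calls ds_opt_adam.destroy_adam(opt_id) and erases the C++
    # optimizer state from s_optimizers. Before this series that state
    # survived for the whole process, so memory grew with each trial.
    del optimizer, model

Since destroy_adam_optimizer just erases the map entry, calling it for an already-removed opt_id is a harmless no-op, so tying cleanup to garbage collection via __del__ is safe even if finalization order varies.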