Skip to content
8 changes: 8 additions & 0 deletions csrc/adam/cpu_adam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,11 +672,19 @@ int ds_adam_step_plus_copy(int optimizer_id,
return 0;
}

int destroy_adam_optimizer(int optimizer_id)
{
s_optimizers.erase(optimizer_id);

return 0;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_update_copy",
&ds_adam_step_plus_copy,
"DeepSpeed CPU Adam update and param copy (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)");
}
5 changes: 5 additions & 0 deletions deepspeed/ops/adam/cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def __init__(self,
weight_decay,
adamw_mode)

def __del__(self):
# need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
# is used multiple times in the same process (notebook or pytest worker)
self.ds_opt_adam.destroy_adam(self.opt_id)

def __setstate__(self, state):
super(DeepSpeedCPUAdam, self).__setstate__(state)
for group in self.param_groups:
Expand Down