Skip to content

Commit df71c76

Browse files
authored
Fix memory leak from _hp_mapping (#5643)
See #5496 I don't really know if this is a good solution
1 parent b3767d0 commit df71c76

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

deepspeed/runtime/bf16_optimizer.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
setattr(sys.modules[__name__], 'fragment_address', fragment_address)
2727

2828

29+
def print_rank_0(message, debug=False, force=False):
30+
if dist.get_rank() == 0 and (debug or force):
31+
print(message)
32+
33+
2934
class BF16_Optimizer(ZeROOptimizer):
3035

3136
def __init__(self,
@@ -92,7 +97,16 @@ def __init__(self,
9297
if self.using_real_optimizer:
9398
self._setup_for_real_optimizer()
9499

95-
see_memory_usage('end bf16_optimizer', force=True)
100+
see_memory_usage('end bf16_ optimizer', force=True)
101+
102+
def destroy(self):
103+
for i, _ in enumerate(self.optimizer.param_groups):
104+
for p in self.bf16_groups[i]:
105+
if getattr(p, '_hp_mapping', None):
106+
p._hp_mapping = None
107+
for hook in self._grad_acc_hooks:
108+
hook.remove()
109+
print_rank_0("Removed grad acc hooks")
96110

97111
def _configure_moe_settings(self):
98112
assert any(
@@ -187,6 +201,7 @@ def _setup_for_real_optimizer(self):
187201
self.initialize_optimizer_states()
188202
see_memory_usage('end initialize_optimizer', force=True)
189203

204+
self._grad_acc_hooks = []
190205
if self.immediate_grad_update:
191206
self.create_grad_acc_hooks()
192207

@@ -541,7 +556,7 @@ def wrapper(param, i, j):
541556
def accumulate_hp_grads_and_remove_lp(*notneeded):
542557
self.accumulate_hp_grads_and_remove_lp(param, i, j)
543558

544-
grad_acc.register_hook(accumulate_hp_grads_and_remove_lp)
559+
self._grad_acc_hooks.append(grad_acc.register_hook(accumulate_hp_grads_and_remove_lp))
545560
self.grad_accs.append(grad_acc)
546561

547562
wrapper(param, i, j)

deepspeed/runtime/zero/stage_1_and_2.py

+4
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,10 @@ def __init__(self,
552552
self._param_slice_mappings = self._create_param_mapping()
553553

554554
def destroy(self):
555+
for i, _ in enumerate(self.optimizer.param_groups):
556+
for p in self.bit16_groups[i]:
557+
if getattr(p, '_hp_mapping', None):
558+
p._hp_mapping = None
555559
for hook in self._grad_acc_hooks:
556560
hook.remove()
557561
self.print_rank_0("Removed grad acc hooks")

0 commit comments

Comments
 (0)