|
26 | 26 | setattr(sys.modules[__name__], 'fragment_address', fragment_address)
|
27 | 27 |
|
28 | 28 |
|
def print_rank_0(message, debug=False, force=False):
    """Print ``message`` on rank 0 only, and only when explicitly enabled.

    Args:
        message: Object passed to ``print``.
        debug: When true, enables printing.
        force: When true, enables printing.

    With both flags left at their defaults nothing is printed, so a plain
    ``print_rank_0(msg)`` call is silent.
    """
    # Check the cheap local flags first: on the default (silent) path this
    # skips the rank query entirely, so the helper cannot fail or add overhead
    # when nothing would be printed anyway.
    if (debug or force) and dist.get_rank() == 0:
        print(message)
| 32 | + |
| 33 | + |
29 | 34 | class BF16_Optimizer(ZeROOptimizer):
|
30 | 35 |
|
31 | 36 | def __init__(self,
|
@@ -92,7 +97,16 @@ def __init__(self,
|
92 | 97 | if self.using_real_optimizer:
|
93 | 98 | self._setup_for_real_optimizer()
|
94 | 99 |
|
95 |
| - see_memory_usage('end bf16_optimizer', force=True) |
| 100 | + see_memory_usage('end bf16_ optimizer', force=True) |
| 101 | + |
def destroy(self):
    """Drop per-parameter ``_hp_mapping`` references and remove grad hooks.

    For each index of ``self.optimizer.param_groups``, every parameter in the
    matching ``self.bf16_groups`` entry that has a truthy ``_hp_mapping``
    gets that attribute reset to ``None``. Every handle stored in
    ``self._grad_acc_hooks`` is then removed.

    Safe to call on an instance where ``_setup_for_real_optimizer`` never
    ran (so ``_grad_acc_hooks`` was never created), and safe to call more
    than once.
    """
    for i, _ in enumerate(self.optimizer.param_groups):
        for p in self.bf16_groups[i]:
            if getattr(p, '_hp_mapping', None):
                p._hp_mapping = None

    # ``_grad_acc_hooks`` is only assigned in ``_setup_for_real_optimizer``,
    # which ``__init__`` calls only when ``using_real_optimizer`` is true;
    # read it defensively so teardown does not raise AttributeError otherwise.
    hooks = getattr(self, '_grad_acc_hooks', None) or []
    for hook in hooks:
        hook.remove()
    # Forget the handles so a repeated destroy() has nothing left to process.
    hooks.clear()

    # Silent with print_rank_0's default flags (it prints only when ``debug``
    # or ``force`` is set); kept as a hook point for debugging.
    print_rank_0("Removed grad acc hooks")
96 | 110 |
|
97 | 111 | def _configure_moe_settings(self):
|
98 | 112 | assert any(
|
@@ -187,6 +201,7 @@ def _setup_for_real_optimizer(self):
|
187 | 201 | self.initialize_optimizer_states()
|
188 | 202 | see_memory_usage('end initialize_optimizer', force=True)
|
189 | 203 |
|
| 204 | + self._grad_acc_hooks = [] |
190 | 205 | if self.immediate_grad_update:
|
191 | 206 | self.create_grad_acc_hooks()
|
192 | 207 |
|
@@ -541,7 +556,7 @@ def wrapper(param, i, j):
|
541 | 556 | def accumulate_hp_grads_and_remove_lp(*notneeded):
|
542 | 557 | self.accumulate_hp_grads_and_remove_lp(param, i, j)
|
543 | 558 |
|
544 |
| - grad_acc.register_hook(accumulate_hp_grads_and_remove_lp) |
| 559 | + self._grad_acc_hooks.append(grad_acc.register_hook(accumulate_hp_grads_and_remove_lp)) |
545 | 560 | self.grad_accs.append(grad_acc)
|
546 | 561 |
|
547 | 562 | wrapper(param, i, j)
|
|
0 commit comments