
Commit cfc6ed3

bf16_optimizer: fixes to different grad acc dtype (#6485)
- fix the step function to cast gradients to FP32 before the optimizer step when the gradient accumulation data type differs
- remove the redundant function initialize_optimizer_states()
1 parent 9b7fc54 · commit cfc6ed3

File tree

1 file changed: +10 -23 lines


deepspeed/runtime/bf16_optimizer.py

+10-23
@@ -197,10 +197,6 @@ def _setup_for_real_optimizer(self):
 
             see_memory_usage(f'after initializing group {i}', force=True)
 
-        see_memory_usage('before initialize_optimizer', force=True)
-        self.initialize_optimizer_states()
-        see_memory_usage('end initialize_optimizer', force=True)
-
         self._grad_acc_hooks = []
         if self.immediate_grad_update:
             self.create_grad_acc_hooks()
@@ -252,25 +248,6 @@ def _lazy_init_hp_params_optimizer_state(self):
                                                     self.optimizer.state)
             self._hp_optimizer_states_linked = True
 
-    def initialize_optimizer_states(self):
-        """Take an optimizer step with zero-valued gradients to allocate internal
-        optimizer state.
-
-        This helps prevent memory fragmentation by allocating optimizer state at the
-        beginning of training instead of after activations have been allocated.
-        """
-        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
-                                                    self.fp32_groups_gradient_flat_partition):
-            # In case of grad acc dtype different than FP32, need to cast to high precision.
-            param_partition.grad = grad_partition.to(
-                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
-
-        if self.grad_acc_dtype is not torch.float32:
-            for param_partition in self.fp32_groups_flat_partition:
-                param_partition.grad = None
-
-        self.clear_hp_grads()
-
     def _split_flat_tensor(self, flat_tensor, num_elem_list):
         assert sum(num_elem_list) <= flat_tensor.numel()
         tensor_list = []
@@ -317,8 +294,18 @@ def step(self, closure=None):
                                        mpu=self.mpu,
                                        use_graph=self.graph_harvesting)
 
+        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
+                                                    self.fp32_groups_gradient_flat_partition):
+            # In case of grad acc dtype different than FP32, need to cast to high precision.
+            param_partition.grad = grad_partition.to(
+                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
+
         self.optimizer.step()
 
+        if self.grad_acc_dtype is not torch.float32:
+            for param_partition in self.fp32_groups_flat_partition:
+                param_partition.grad = None
+
         # We need to link optimizer state after the first step() call
         self._lazy_init_hp_params_optimizer_state()
 
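Note: the hunks above move the grad-dtype handling out of the removed initialize_optimizer_states() and into step(). As a minimal standalone sketch of that pattern (not the DeepSpeed API; the names fp32_partitions, grad_partitions and inner_optimizer are illustrative), gradients accumulated in a lower-precision dtype are cast to the FP32 master parameters' dtype just before the wrapped optimizer steps, and the temporary FP32 copies are released afterwards:

    import torch

    # FP32 master copies of the parameters the wrapped optimizer updates.
    fp32_partitions = [torch.zeros(8, dtype=torch.float32, requires_grad=True)]
    # Gradients accumulated in a different dtype (here bf16), i.e. grad_acc_dtype != fp32.
    grad_partitions = [torch.randn(8, dtype=torch.bfloat16)]
    grad_acc_dtype = torch.bfloat16

    inner_optimizer = torch.optim.AdamW(fp32_partitions, lr=1e-3)

    # Cast to high precision only when the accumulation dtype differs from FP32.
    for param, grad in zip(fp32_partitions, grad_partitions):
        param.grad = grad.to(param.dtype) if grad.dtype != param.dtype else grad

    inner_optimizer.step()

    # The cast FP32 gradients are temporaries; drop them so only the
    # low-precision accumulation buffers stay alive between steps.
    if grad_acc_dtype is not torch.float32:
        for param in fp32_partitions:
            param.grad = None

This mirrors what the new code in step() does: the FP32 gradient copies exist only around the inner optimizer step rather than being set up once ahead of training.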