@@ -197,10 +197,6 @@ def _setup_for_real_optimizer(self):

             see_memory_usage(f'after initializing group {i}', force=True)

-        see_memory_usage('before initialize_optimizer', force=True)
-        self.initialize_optimizer_states()
-        see_memory_usage('end initialize_optimizer', force=True)
-
         self._grad_acc_hooks = []
         if self.immediate_grad_update:
             self.create_grad_acc_hooks()
@@ -252,25 +248,6 @@ def _lazy_init_hp_params_optimizer_state(self):
                                                      self.optimizer.state)
             self._hp_optimizer_states_linked = True

-    def initialize_optimizer_states(self):
-        """Take an optimizer step with zero-valued gradients to allocate internal
-        optimizer state.
-
-        This helps prevent memory fragmentation by allocating optimizer state at the
-        beginning of training instead of after activations have been allocated.
-        """
-        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
-                                                   self.fp32_groups_gradient_flat_partition):
-            # In case of grad acc dtype different than FP32, need to cast to high precision.
-            param_partition.grad = grad_partition.to(
-                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
-
-        if self.grad_acc_dtype is not torch.float32:
-            for param_partition in self.fp32_groups_flat_partition:
-                param_partition.grad = None
-
-        self.clear_hp_grads()
-
     def _split_flat_tensor(self, flat_tensor, num_elem_list):
         assert sum(num_elem_list) <= flat_tensor.numel()
         tensor_list = []
@@ -317,8 +294,18 @@ def step(self, closure=None):
                                         mpu=self.mpu,
                                         use_graph=self.graph_harvesting)

+        for param_partition, grad_partition in zip(self.fp32_groups_flat_partition,
+                                                   self.fp32_groups_gradient_flat_partition):
+            # In case of grad acc dtype different than FP32, need to cast to high precision.
+            param_partition.grad = grad_partition.to(
+                param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition
+
         self.optimizer.step()

+        if self.grad_acc_dtype is not torch.float32:
+            for param_partition in self.fp32_groups_flat_partition:
+                param_partition.grad = None
+
         # We need to link optimizer state after the first step() call
         self._lazy_init_hp_params_optimizer_state()

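For reference, a minimal, self-contained sketch of the grad-attach pattern that `step()` applies above: the lower-precision gradient partitions are cast to the dtype of the FP32 master partitions only when the dtypes differ, the wrapped optimizer takes its step, and the temporary FP32 casts are then released. The attribute names (`fp32_groups_flat_partition`, `fp32_groups_gradient_flat_partition`, `grad_acc_dtype`) come from the diff; the dummy tensors and the `AdamW` instance are illustrative assumptions, not the DeepSpeed wiring.

```python
# Minimal sketch (assumptions, not the DeepSpeed code): attach gradient
# partitions to the FP32 master partitions, casting to FP32 only when the
# grad-accumulation dtype differs, step the wrapped optimizer, then drop the
# temporary FP32 grads so no extra high-precision copies are retained.
import torch

# Illustrative stand-ins for self.fp32_groups_flat_partition,
# self.fp32_groups_gradient_flat_partition and self.grad_acc_dtype.
fp32_groups_flat_partition = [torch.zeros(8, dtype=torch.float32, requires_grad=True)]
fp32_groups_gradient_flat_partition = [torch.ones(8, dtype=torch.bfloat16)]
grad_acc_dtype = torch.bfloat16

optimizer = torch.optim.AdamW(fp32_groups_flat_partition, lr=1e-3)

# Cast to high precision only when the dtypes actually differ.
for param_partition, grad_partition in zip(fp32_groups_flat_partition,
                                           fp32_groups_gradient_flat_partition):
    param_partition.grad = (grad_partition.to(param_partition.dtype)
                            if grad_partition.dtype != param_partition.dtype else grad_partition)

optimizer.step()

# The FP32 grads were temporary casts; release them when grad accumulation is not FP32.
if grad_acc_dtype is not torch.float32:
    for param_partition in fp32_groups_flat_partition:
        param_partition.grad = None
```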