 from nemo.collections.nlp.parts import utils_funcs
 from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
 from nemo.core.optim import MainParamsOptimizerWrapper
+from nemo.core.optim.optimizers import init_optimizer_states
 from nemo.utils import AppState, logging
 from nemo.utils.get_rank import is_global_rank_zero
 from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank

 try:
     from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+    from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

     HAVE_APEX = True
@@ -259,7 +261,7 @@ def optimizer_sharded_state_dict(self):
             ValueError: If a parameter ID does not match any model sharded parameter.
         """

-        optimizer = self.lightning_module.optimizers(use_pl_optimizer=False)  # MainParamsOptimizerWrapper
+        optimizer = self.lightning_module.optimizers(use_pl_optimizer=False)

         model_sharded_state_dict = self.lightning_module.sharded_state_dict()
@@ -268,8 +270,21 @@ def optimizer_sharded_state_dict(self):
             key: value for key, value in model_sharded_state_dict.items() if not key.endswith('_extra_state')
         }

-        if not isinstance(optimizer, MainParamsOptimizerWrapper):
+        if isinstance(optimizer, MegatronDistributedFusedAdam):
             return optimizer.sharded_state_dict(model_sharded_state_dict)
+        elif not isinstance(optimizer, MainParamsOptimizerWrapper):
+            # Regular optimizer, e.g. Adam or FusedAdam
+            init_optimizer_states(optimizer)
+            optimizer_state_dict = optimizer.state_dict()
+            id_to_sharded_param_map = get_param_id_to_sharded_param_map(
+                model_sharded_state_dict=model_sharded_state_dict,
+                optim_params_iter=itertools.chain.from_iterable(g['params'] for g in optimizer.param_groups),
+            )
+            optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)
+            return optimizer_state_dict
+
+        # MainParamsOptimizerWrapper
+        init_optimizer_states(optimizer.optimizer)

         optimizer_state_dict = optimizer.state_dict()
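Context for the new non-wrapped branch: a plain optimizer such as torch.optim.Adam creates its per-parameter state (exp_avg, exp_avg_sq) lazily on the first step(), so its state_dict() is empty until then, and an empty state cannot be mapped onto sharded parameters. This is presumably why both new branches initialize the optimizer state before reading it. A minimal PyTorch-only sketch of that behaviour; the zero-gradient step used here to force state allocation is an illustrative stand-in, not NeMo's init_optimizer_states implementation:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Adam allocates exp_avg / exp_avg_sq lazily, so before the first step
# there is nothing in the state dict to shard or checkpoint.
assert len(optimizer.state_dict()["state"]) == 0

# Materialize the state without moving the weights: an Adam step with
# all-zero gradients leaves the parameters unchanged but creates the
# per-parameter state tensors.
for param in model.parameters():
    param.grad = torch.zeros_like(param)
optimizer.step()
optimizer.zero_grad()

state = optimizer.state_dict()["state"]
assert all("exp_avg" in s and "exp_avg_sq" in s for s in state.values())
print(f"materialized state for {len(state)} parameters")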