
Commit 922c777

mikolajblaz, ericharper, XuesongYang, dimapihtar, and pre-commit-ci[bot] authored and committed
Add dist ckpt support for regular optimizers (NVIDIA#7749)
* Add dist ckpt support for regular optimizers
  Signed-off-by: Mikołaj Błaż <[email protected]>
* [tutorial] fixed missing RIR scripts file. (NVIDIA#8257)
  Signed-off-by: Xuesong Yang <[email protected]>
* fix imports
  Signed-off-by: dimapihtar <[email protected]>
* imports fix
  Signed-off-by: dimapihtar <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* ci imports fix
  Signed-off-by: dimapihtar <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* revert asr notebook
  Signed-off-by: dimapihtar <[email protected]>
* revert asr notebook
  Signed-off-by: dimapihtar <[email protected]>

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: Xuesong Yang <[email protected]>
Co-authored-by: Dmytro Pykhtar <[email protected]>
Co-authored-by: dimapihtar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent b1016ae commit 922c777

File tree: 3 files changed (+32 -15 lines)


nemo/collections/nlp/parts/nlp_overrides.py (+17 -2)
@@ -60,12 +60,14 @@
 from nemo.collections.nlp.parts import utils_funcs
 from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
 from nemo.core.optim import MainParamsOptimizerWrapper
+from nemo.core.optim.optimizers import init_optimizer_states
 from nemo.utils import AppState, logging
 from nemo.utils.get_rank import is_global_rank_zero
 from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank

 try:
     from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+    from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

     HAVE_APEX = True

@@ -259,7 +261,7 @@ def optimizer_sharded_state_dict(self):
             ValueError: If a parameter ID does not match any model sharded parameter.
         """

-        optimizer = self.lightning_module.optimizers(use_pl_optimizer=False)  # MainParamsOptimizerWrapper
+        optimizer = self.lightning_module.optimizers(use_pl_optimizer=False)

         model_sharded_state_dict = self.lightning_module.sharded_state_dict()

@@ -268,8 +270,21 @@ def optimizer_sharded_state_dict(self):
             key: value for key, value in model_sharded_state_dict.items() if not key.endswith('_extra_state')
         }

-        if not isinstance(optimizer, MainParamsOptimizerWrapper):
+        if isinstance(optimizer, MegatronDistributedFusedAdam):
             return optimizer.sharded_state_dict(model_sharded_state_dict)
+        elif not isinstance(optimizer, MainParamsOptimizerWrapper):
+            # Regular optimizer, e.g. Adam or FusedAdam
+            init_optimizer_states(optimizer)
+            optimizer_state_dict = optimizer.state_dict()
+            id_to_sharded_param_map = get_param_id_to_sharded_param_map(
+                model_sharded_state_dict=model_sharded_state_dict,
+                optim_params_iter=itertools.chain.from_iterable(g['params'] for g in optimizer.param_groups),
+            )
+            optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)
+            return optimizer_state_dict
+
+        # MainParamsOptimizerWrapper
+        init_optimizer_states(optimizer.optimizer)

         optimizer_state_dict = optimizer.state_dict()
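
To make the new control flow easier to follow, here is a minimal standalone sketch (not part of the commit) of the three-way dispatch that optimizer_sharded_state_dict now performs. It assumes the megatron-core helpers get_param_id_to_sharded_param_map and optim_state_to_sharding_state are importable (the surrounding file already uses them; the import path shown is an assumption), and the wrapper name sharded_optim_state is made up for illustration. In the MainParamsOptimizerWrapper branch the real method continues with further processing of the fp32 main-parameter state rather than returning immediately.

# Illustrative sketch only; names and import paths noted above are assumptions.
import itertools

from megatron.core.dist_checkpointing.optimizer import (  # assumed import path
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)
from nemo.core.optim import MainParamsOptimizerWrapper
from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
from nemo.core.optim.optimizers import init_optimizer_states


def sharded_optim_state(optimizer, model_sharded_state_dict):
    """Simplified mirror of the dispatch added to optimizer_sharded_state_dict."""
    if isinstance(optimizer, MegatronDistributedFusedAdam):
        # The distributed optimizer shards its own state.
        return optimizer.sharded_state_dict(model_sharded_state_dict)

    if not isinstance(optimizer, MainParamsOptimizerWrapper):
        # Regular optimizer, e.g. Adam or FusedAdam: create exp_avg/exp_avg_sq
        # eagerly, then map each per-param entry onto the model's sharded tensors.
        init_optimizer_states(optimizer)
        optimizer_state_dict = optimizer.state_dict()
        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
            model_sharded_state_dict=model_sharded_state_dict,
            optim_params_iter=itertools.chain.from_iterable(
                g['params'] for g in optimizer.param_groups
            ),
        )
        optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)
        return optimizer_state_dict

    # MainParamsOptimizerWrapper: initialize the wrapped optimizer's state first;
    # the real method goes on to build sharded state for the fp32 main params.
    init_optimizer_states(optimizer.optimizer)
    return optimizer.state_dict()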

nemo/core/optim/optimizer_with_main_params.py (-13)
@@ -312,9 +312,6 @@ def __init__(
             self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

-        # init exp_avg and exp_avg_sq before loading optimizer state, needed for dist checkpointing
-        self._init_opt_state()
-
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())

@@ -543,13 +540,3 @@ def _set_defaults(self, value):
         self.optimizer.defaults = value

     defaults = property(_get_defaults, _set_defaults)
-
-    def _init_opt_state(self):
-        """
-        Initialize the optimizer state with zero tensors for 'exp_avg' and 'exp_avg_sq' of each parameter.
-        """
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                if len(self.optimizer.state[p]) == 0:
-                    self.optimizer.state[p]['exp_avg'] = torch.zeros_like(p.data)
-                    self.optimizer.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)

nemo/core/optim/optimizers.py (+15)
@@ -200,3 +200,18 @@ def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer:
     optimizer = AVAILABLE_OPTIMIZERS[name]
     optimizer = partial(optimizer, **kwargs)
     return optimizer
+
+
+def init_optimizer_states(optimizer: Optimizer):
+    adam_nondist_optims = (optim.Adam, optim.AdamW)
+    if HAVE_APEX:
+        adam_nondist_optims += (FusedAdam,)
+    if isinstance(optimizer, adam_nondist_optims):
+        for group in optimizer.param_groups:
+            for p in group['params']:
+                state = optimizer.state[p]
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    if group.get('amsgrad'):
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
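
As a quick illustration of what init_optimizer_states changes, the snippet below (a minimal sketch, not part of the commit; the toy torch.nn.Linear model is made up, and NeMo with this change is assumed to be installed) shows that a fresh Adam optimizer carries no per-parameter state until its first step, while calling the helper fills in zeroed exp_avg/exp_avg_sq buffers, so a state dict taken before any training step already has one entry per parameter for distributed checkpointing to map onto sharded tensors.

# Minimal sketch: effect of init_optimizer_states on a plain torch.optim.Adam.
# (Toy model for illustration; assumes NeMo with this commit is installed.)
import torch

from nemo.core.optim.optimizers import init_optimizer_states

model = torch.nn.Linear(4, 2)  # stand-in for a real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(len(optimizer.state))  # 0 -- no per-param state before the first step

init_optimizer_states(optimizer)

first_param = next(model.parameters())
state = optimizer.state[first_param]
print(sorted(state.keys()))                 # ['exp_avg', 'exp_avg_sq']
print(float(state['exp_avg'].abs().sum()))  # 0.0 -- zero-initialized buffers

# optimizer.state_dict() now has an entry per parameter id, which is what
# get_param_id_to_sharded_param_map / optim_state_to_sharding_state expect.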
