From 92162193e511772d1607876baa081fa9646f3b0f Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 16 Feb 2024 10:15:54 -0700
Subject: [PATCH] Add dist ckpt support for regular optimizers (#7749) (#8293)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add dist ckpt support for regular optimizers

* [tutorial] fixed missing RIR scripts file. (#8257)

* fix imports

* imports fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ci imports fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert asr notebook

* revert asr notebook

---------

Signed-off-by: Mikołaj Błaż
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: dimapihtar
Co-authored-by: mikolajblaz
Co-authored-by: Eric Harper
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Co-authored-by: dimapihtar
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 nemo/collections/nlp/parts/nlp_overrides.py   | 19 +++++++++++++++++--
 nemo/core/optim/optimizer_with_main_params.py | 13 -------------
 nemo/core/optim/optimizers.py                 | 15 +++++++++++++++
 3 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py
index 1012fdf71405b..cde0188dff20a 100644
--- a/nemo/collections/nlp/parts/nlp_overrides.py
+++ b/nemo/collections/nlp/parts/nlp_overrides.py
@@ -60,12 +60,14 @@
 from nemo.collections.nlp.parts import utils_funcs
 from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
 from nemo.core.optim import MainParamsOptimizerWrapper
+from nemo.core.optim.optimizers import init_optimizer_states
 from nemo.utils import AppState, logging
 from nemo.utils.get_rank import is_global_rank_zero
 from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank
 
 try:
     from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+
     from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
 
     HAVE_APEX = True
@@ -259,7 +261,7 @@ def optimizer_sharded_state_dict(self):
             ValueError: If a parameter ID does not match any model sharded parameter.
         """
 
-        optimizer = self.lightning_module.optimizers(use_pl_optimizer=False)  # MainParamsOptimizerWrapper
+        optimizer = self.lightning_module.optimizers(use_pl_optimizer=False)
 
         model_sharded_state_dict = self.lightning_module.sharded_state_dict()
 
@@ -268,8 +270,21 @@ def optimizer_sharded_state_dict(self):
             key: value for key, value in model_sharded_state_dict.items() if not key.endswith('_extra_state')
         }
 
-        if not isinstance(optimizer, MainParamsOptimizerWrapper):
+        if isinstance(optimizer, MegatronDistributedFusedAdam):
             return optimizer.sharded_state_dict(model_sharded_state_dict)
+        elif not isinstance(optimizer, MainParamsOptimizerWrapper):
+            # Regular optimizer, e.g. Adam or FusedAdam
+            init_optimizer_states(optimizer)
+            optimizer_state_dict = optimizer.state_dict()
+            id_to_sharded_param_map = get_param_id_to_sharded_param_map(
+                model_sharded_state_dict=model_sharded_state_dict,
+                optim_params_iter=itertools.chain.from_iterable(g['params'] for g in optimizer.param_groups),
+            )
+            optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)
+            return optimizer_state_dict
+
+        # MainParamsOptimizerWrapper
+        init_optimizer_states(optimizer.optimizer)
 
         optimizer_state_dict = optimizer.state_dict()
 
diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py
index 680b82ed7201b..7f8794f746df8 100644
--- a/nemo/core/optim/optimizer_with_main_params.py
+++ b/nemo/core/optim/optimizer_with_main_params.py
@@ -312,9 +312,6 @@ def __init__(
             self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
 
-        # init exp_avg and exp_avg_sq before loading optimizer state, needed for dist checkpointing
-        self._init_opt_state()
-
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())
@@ -543,13 +540,3 @@ def _set_defaults(self, value):
         self.optimizer.defaults = value
 
     defaults = property(_get_defaults, _set_defaults)
-
-    def _init_opt_state(self):
-        """
-        Initialize the optimizer state with zero tensors for 'exp_avg' and 'exp_avg_sq' of each parameter.
-        """
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                if len(self.optimizer.state[p]) == 0:
-                    self.optimizer.state[p]['exp_avg'] = torch.zeros_like(p.data)
-                    self.optimizer.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
diff --git a/nemo/core/optim/optimizers.py b/nemo/core/optim/optimizers.py
index 1d52a9bf10f84..2cc6be0dfc231 100644
--- a/nemo/core/optim/optimizers.py
+++ b/nemo/core/optim/optimizers.py
@@ -200,3 +200,18 @@ def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer:
     optimizer = AVAILABLE_OPTIMIZERS[name]
     optimizer = partial(optimizer, **kwargs)
     return optimizer
+
+
+def init_optimizer_states(optimizer: Optimizer):
+    adam_nondist_optims = (optim.Adam, optim.AdamW)
+    if HAVE_APEX:
+        adam_nondist_optims += (FusedAdam,)
+    if isinstance(optimizer, adam_nondist_optims):
+        for group in optimizer.param_groups:
+            for p in group['params']:
+                state = optimizer.state[p]
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    if group.get('amsgrad'):
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
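
Note (not part of the patch): the minimal standalone sketch below illustrates why the new
init_optimizer_states() helper is needed. A plain torch.optim.Adam holds no per-parameter
state until its first step(), so an optimizer state dict taken at checkpoint time would have
nothing for dist checkpointing to shard; pre-filling zero moment buffers, as the helper added
in nemo/core/optim/optimizers.py does, makes the state complete up front. The toy Linear
module below is illustrative only, not taken from the NeMo code.

import torch

# Toy setup -- any module/optimizer pair behaves the same way.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Before the first step() the optimizer has no per-parameter state,
# so there are no exp_avg / exp_avg_sq tensors to map to sharded tensors.
assert all(len(optimizer.state[p]) == 0 for p in model.parameters())

# Mirror what init_optimizer_states() does for Adam-style optimizers:
# pre-populate the moment buffers with zeros so state_dict() is complete.
for group in optimizer.param_groups:
    for p in group['params']:
        state = optimizer.state[p]
        if len(state) == 0:
            state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
            state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)

# Every parameter now has state entries that optim_state_to_sharding_state() can map.
print(optimizer.state_dict()['state'])

The same reasoning explains the nlp_overrides.py change: the regular-optimizer branch calls
init_optimizer_states() before building the param-id-to-sharded-param map, so every parameter
already has the state tensors that get converted to sharded state.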