Commit
[NeMo-UX] Add mixed-precision plugin (#9065)
* Adding MegatronParallel
* Move over _strategy_lib, MegatronCheckpointIO
* Adding GPTModel & MockDataModule
* Adding mixed-precision to NeMo
* Fix import
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Revert unintended changes
* Clean up code and reinstate mixed-precision tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Clean up
* Use CPU for unit test

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: Ao Tang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Chen Cui <[email protected]>
Showing 4 changed files with 232 additions and 57 deletions.
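Before the diffs, a minimal sketch of how a precision plugin like the one added here would typically be handed to a PyTorch Lightning `Trainer`. The `Trainer` arguments below are illustrative assumptions; this commit only adds the plugin itself.

```python
# Illustrative wiring only; the Trainer configuration is an assumption, not part of this commit.
import pytorch_lightning as pl

from nemo.lightning.pytorch.plugins import MegatronMixedPrecision

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    plugins=[MegatronMixedPrecision(precision="bf16-mixed")],  # or "16-mixed"
)
```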
nemo/lightning/pytorch/plugins/__init__.py

@@ -1,3 +1,7 @@
 from nemo.lightning.pytorch.plugins.data_sampler import MegatronDataSampler
+from nemo.lightning.pytorch.plugins.mixed_precision import MegatronMixedPrecision
 
-__all__ = ["MegatronDataSampler"]
+__all__ = [
+    "MegatronDataSampler",
+    "MegatronMixedPrecision",
+]
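For orientation: after this change both plugins are re-exported at the package level, so user code can import them directly from `nemo.lightning.pytorch.plugins`. A one-line check, assuming the package from this commit is installed:

```python
# Both names resolve from the package namespace after the __all__ update above.
from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision
```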
nemo/lightning/pytorch/plugins/mixed_precision.py

@@ -0,0 +1,166 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union

import pytorch_lightning as pl
import torch
from pytorch_lightning.plugins.precision import MixedPrecision
from torch.nn import Module
from torch.optim import Optimizer

from nemo.lightning._strategy_lib import GradScaler

AnyT = TypeVar("AnyT")


class MegatronMixedPrecision(MixedPrecision):
    def __init__(self, precision: Literal["16-mixed", "bf16-mixed"], amp_O2: bool = True, device="cuda",) -> None:
        if precision == "bf16-mixed":
            scaler = None
        else:
            scaler = GradScaler(init_scale=2 ** 32, growth_interval=1000, hysteresis=2)

        super().__init__(precision, device, scaler)

        # MixedPrecisionPlugin class in PTL >= 2.0 takes only "16-mixed" or "bf16-mixed" for precision arg
        if precision == "16-mixed":
            dtype = torch.float16

            def float16_convertor(val):
                return val.half()

        elif precision == "bf16-mixed":
            dtype = torch.bfloat16

            def float16_convertor(val):
                return val.bfloat16()

        else:
            raise ValueError("precision must be '16-mixed' or 'bf16-mixed'")

        self.dtype = dtype
        torch.set_autocast_gpu_dtype(dtype)
        self.float16_convertor = float16_convertor
        self.amp_O2 = amp_O2

    def connect(
        self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Connects this plugin to the accelerator and the training process."""
        from nemo.core.optim import MainParamsOptimizerWrapper

        if not optimizers or not self.amp_O2 or isinstance(optimizers[0], MainParamsOptimizerWrapper):
            return model, optimizers, lr_schedulers

        _optimizers = [*optimizers]
        _optimizers[0] = self.convert_optimizer(_optimizers[0])

        return model, _optimizers, lr_schedulers

    def convert_module(self, module: Module) -> Module:
        """Convert the module parameters to the precision type this plugin handles.
        This is optional and depends on the precision limitations during optimization.
        """
        if self.precision == "bf16-mixed":
            return module.bfloat16()
        if self.precision == "16-mixed":
            return module.half()

        return module

    def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Convert the optimizer parameters to the precision type this plugin handles.
        This is optional and depends on the precision limitations during optimization.
        """
        from nemo.core.optim import MainParamsOptimizerWrapper

        if isinstance(optimizer, MainParamsOptimizerWrapper) or not self.amp_O2:
            return optimizer

        return MainParamsOptimizerWrapper(optimizer, fp32_grad_accum=True, contiguous_grad_bucket=True,)

    def convert_input(self, data: AnyT) -> AnyT:
        """Convert model inputs (forward) to the floating point precision type of this plugin.
        Note: MegatronStrategy will take care of only doing this when:
            parallel_state.is_pipeline_first_stage()
        """
        from megatron.core.transformer.module import fp32_to_float16

        return fp32_to_float16(data, self.float16_convertor)

    def convert_output(self, data: AnyT) -> AnyT:
        """Convert outputs to the floating point precision type expected after model's forward.
        Note: MegatronStrategy will take care of only doing this when:
            parallel_state.is_pipeline_last_stage()
        """
        from megatron.core.transformer.module import float16_to_fp32

        return float16_to_fp32(data)

    def optimizer_step(
        self,
        optimizer: torch.optim.Optimizer,
        model: Union["pl.LightningModule", torch.nn.Module],
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> None:
        from nemo.core.optim import MainParamsOptimizerWrapper

        if not self.amp_O2 and not isinstance(optimizer, MainParamsOptimizerWrapper):
            return super().optimizer_step(optimizer, model, closure, **kwargs)

        if self.scaler is None:
            assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation"
            _ = closure()
            self._after_closure(model, optimizer)
            return optimizer.step(**kwargs)

        assert not optimizer.fp32_grad_accumulation, "FP16 uses FP16 grad accumulation"
        closure_result = closure()

        # TODO: Add an option for merged all-reduce

        # cast fp16 grads to fp32 and copy to main grads, which are used for unscale and param update
        optimizer.copy_model_grads_to_main_grads()
        # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
        # unscale main (fp32) gradients
        self.scaler.unscale_(optimizer)
        self._after_closure(model, optimizer)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            self.scaler.step(optimizer, **kwargs)
            self.scaler.update()

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """No explicit precision casting. Inputs are supposed to be manually casted."""
        try:
            yield
        finally:
            pass


__all__ = ["MegatronMixedPrecision"]
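To make the two supported modes concrete, here is a small, hedged sketch of the constructor's behavior as written above: "16-mixed" builds a `GradScaler` with a large initial scale, while "bf16-mixed" runs without one because bf16 keeps an fp32-like dynamic range. The `dtype` and `scaler` attributes come from the plugin and its Lightning base class.

```python
# Hedged sketch, not part of the diff: only the constructor paths shown above are exercised.
from nemo.lightning.pytorch.plugins import MegatronMixedPrecision

fp16_plugin = MegatronMixedPrecision(precision="16-mixed")    # builds GradScaler(init_scale=2**32, ...)
bf16_plugin = MegatronMixedPrecision(precision="bf16-mixed")  # no gradient scaler for bf16

print(fp16_plugin.dtype, fp16_plugin.scaler is None)  # torch.float16 False
print(bf16_plugin.dtype, bf16_plugin.scaler is None)  # torch.bfloat16 True
```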