[NeMo-UX] Add mixed-precision plugin (#9065)
* Adding MegatronParallel

* Move over _strategy_lib & MegatronCheckpointIO

* Adding GPTModel & MockDataModule

* Adding mixed-precision to NeMo

* Fix import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert unintended changes

Signed-off-by: Chen Cui <[email protected]>

* clean up code and reinstate mixed-precision tests

Signed-off-by: Chen Cui <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up

Signed-off-by: Chen Cui <[email protected]>

* use cpu for unit test

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Ao Tang <[email protected]>
3 people authored and suiyoubi committed May 2, 2024
1 parent 65cefa8 commit 0310b6d
Showing 4 changed files with 232 additions and 57 deletions.
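For orientation, here is a minimal usage sketch of the plugin this commit adds. Constructing MegatronMixedPrecision(precision=...) matches the re-enabled tests below; the surrounding Trainer/strategy wiring is an illustrative assumption, not part of this diff.

# Hypothetical usage sketch; assumes a NeMo build that includes this commit.
# Only MegatronMixedPrecision(precision=...) is taken from the diff; the rest
# of the wiring (Trainer/MegatronStrategy defaults) is illustrative.
from nemo import lightning as nl

precision = nl.MegatronMixedPrecision(precision="bf16-mixed")  # or "16-mixed" (adds a grad scaler)

trainer = nl.Trainer(
    devices=1,
    accelerator="gpu",
    strategy=nl.MegatronStrategy(),
    plugins=precision,
)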
11 changes: 9 additions & 2 deletions nemo/lightning/__init__.py
@@ -4,7 +4,7 @@
from pytorch_lightning import plugins as _pl_plugins

from nemo.lightning.base import get_vocab_size, teardown
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision
from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
from nemo.lightning.pytorch.strategies import MegatronStrategy
from nemo.lightning.pytorch.trainer import Trainer
@@ -22,4 +22,11 @@ def _is_slurm_interactive_mode():
_pl_plugins._PLUGIN_INPUT = Union[_pl_plugins._PLUGIN_INPUT, _data_sampler.DataSampler] # noqa: SLF001


__all__ = ["MegatronStrategy", "MegatronDataSampler", "Trainer", "get_vocab_size", "teardown"]
__all__ = [
    "MegatronStrategy",
    "MegatronDataSampler",
    "MegatronMixedPrecision",
    "Trainer",
    "get_vocab_size",
    "teardown",
]
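With the widened export above, the plugin can be pulled straight from the package root. A small smoke check of the precision-to-dtype mapping defined in mixed_precision.py below (assuming a NeMo install that contains this commit):

# Assumes a NeMo build that includes this commit.
from nemo.lightning import MegatronMixedPrecision

fp16 = MegatronMixedPrecision(precision="16-mixed")    # also builds a GradScaler internally
bf16 = MegatronMixedPrecision(precision="bf16-mixed")  # no scaler needed for bf16

print(fp16.dtype, bf16.dtype)  # torch.float16 torch.bfloat16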
6 changes: 5 additions & 1 deletion nemo/lightning/pytorch/plugins/__init__.py
@@ -1,3 +1,7 @@
from nemo.lightning.pytorch.plugins.data_sampler import MegatronDataSampler
from nemo.lightning.pytorch.plugins.mixed_precision import MegatronMixedPrecision

__all__ = ["MegatronDataSampler"]
__all__ = [
    "MegatronDataSampler",
    "MegatronMixedPrecision",
]
166 changes: 166 additions & 0 deletions nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -0,0 +1,166 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union

import pytorch_lightning as pl
import torch
from pytorch_lightning.plugins.precision import MixedPrecision
from torch.nn import Module
from torch.optim import Optimizer

from nemo.lightning._strategy_lib import GradScaler

AnyT = TypeVar("AnyT")


class MegatronMixedPrecision(MixedPrecision):
    def __init__(self, precision: Literal["16-mixed", "bf16-mixed"], amp_O2: bool = True, device="cuda",) -> None:
        if precision == "bf16-mixed":
            scaler = None
        else:
            scaler = GradScaler(init_scale=2 ** 32, growth_interval=1000, hysteresis=2)

        super().__init__(precision, device, scaler)

        # MixedPrecisionPlugin class in PTL >= 2.0 takes only "16-mixed" or "bf16-mixed" for precision arg
        if precision == "16-mixed":
            dtype = torch.float16

            def float16_convertor(val):
                return val.half()

        elif precision == "bf16-mixed":
            dtype = torch.bfloat16

            def float16_convertor(val):
                return val.bfloat16()

        else:
            raise ValueError("precision must be '16-mixed' or 'bf16-mixed'")

        self.dtype = dtype
        torch.set_autocast_gpu_dtype(dtype)
        self.float16_convertor = float16_convertor
        self.amp_O2 = amp_O2

    def connect(
        self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Connects this plugin to the accelerator and the training process."""
        from nemo.core.optim import MainParamsOptimizerWrapper

        if not optimizers or not self.amp_O2 or isinstance(optimizers[0], MainParamsOptimizerWrapper):
            return model, optimizers, lr_schedulers

        _optimizers = [*optimizers]
        _optimizers[0] = self.convert_optimizer(_optimizers[0])

        return model, _optimizers, lr_schedulers

    def convert_module(self, module: Module) -> Module:
        """Convert the module parameters to the precision type this plugin handles.
        This is optional and depends on the precision limitations during optimization.
        """
        if self.precision == "bf16-mixed":
            return module.bfloat16()
        if self.precision == "16-mixed":
            return module.half()

        return module

    def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Convert the optimizer parameters to the precision type this plugin handles.
        This is optional and depends on the precision limitations during optimization.
        """
        from nemo.core.optim import MainParamsOptimizerWrapper

        if isinstance(optimizer, MainParamsOptimizerWrapper) or not self.amp_O2:
            return optimizer

        return MainParamsOptimizerWrapper(optimizer, fp32_grad_accum=True, contiguous_grad_bucket=True,)

    def convert_input(self, data: AnyT) -> AnyT:
        """Convert model inputs (forward) to the floating point precision type of this plugin.
        Note: MegatronStrategy will take care of only doing this when:
            parallel_state.is_pipeline_first_stage()
        """
        from megatron.core.transformer.module import fp32_to_float16

        return fp32_to_float16(data, self.float16_convertor)

    def convert_output(self, data: AnyT) -> AnyT:
        """Convert outputs to the floating point precision type expected after model's forward.
        Note: MegatronStrategy will take care of only doing this when:
            parallel_state.is_pipeline_last_stage()
        """
        from megatron.core.transformer.module import float16_to_fp32

        return float16_to_fp32(data)

    def optimizer_step(
        self,
        optimizer: torch.optim.Optimizer,
        model: Union["pl.LightningModule", torch.nn.Module],
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> None:
        from nemo.core.optim import MainParamsOptimizerWrapper

        if not self.amp_O2 and not isinstance(optimizer, MainParamsOptimizerWrapper):
            return super().optimizer_step(optimizer, model, closure, **kwargs)

        if self.scaler is None:
            assert optimizer.fp32_grad_accumulation, "BF16 uses FP32 grad accumulation"
            _ = closure()
            self._after_closure(model, optimizer)
            return optimizer.step(**kwargs)

        assert not optimizer.fp32_grad_accumulation, "FP16 uses FP16 grad accumulation"
        closure_result = closure()

        # TODO: Add an option for merged all-reduce

        # cast fp16 grads to fp32 and copy to main grads, which are used for unscale and param update
        optimizer.copy_model_grads_to_main_grads()
        # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
        # unscale main (fp32) gradients
        self.scaler.unscale_(optimizer)
        self._after_closure(model, optimizer)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            self.scaler.step(optimizer, **kwargs)
            self.scaler.update()

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """No explicit precision casting. Inputs are supposed to be manually casted."""
        try:
            yield
        finally:
            pass


__all__ = ["MegatronMixedPrecision"]
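For readers unfamiliar with the O2-style recipe this class wraps, here is a self-contained plain-PyTorch sketch of the same idea: half-precision working parameters, fp32 "main" parameters owned by the optimizer, and loss scaling on the fp16 path. It uses only torch, so the NeMo/Megatron specifics above (the GradScaler subclass, MainParamsOptimizerWrapper) are deliberately simplified; the variable names are illustrative, not from the diff.

# Illustrative sketch only; a plain-PyTorch stand-in for the fp16 path above.
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.bfloat16  # bf16 path needs no scaler

# "model" params in half precision (convert_module), fp32 "main" params owned by the optimizer
model = nn.Linear(8, 4).to(device=device, dtype=dtype)
main_params = [p.detach().clone().float().requires_grad_(True) for p in model.parameters()]
opt = torch.optim.SGD(main_params, lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == torch.float16))

x = torch.randn(2, 8, device=device, dtype=dtype)  # convert_input
loss = model(x).float().sum()                      # convert_output: back to fp32 for the loss
scaler.scale(loss).backward()

# copy_model_grads_to_main_grads: half-precision grads -> fp32 main grads
for main, p in zip(main_params, model.parameters()):
    main.grad = p.grad.detach().float()

scaler.unscale_(opt)  # unscale the fp32 main grads before clipping/inspection
scaler.step(opt)      # skips the step if non-finite gradients were found
scaler.update()

# copy the updated fp32 main params back into the half-precision model
with torch.no_grad():
    for main, p in zip(main_params, model.parameters()):
        p.copy_(main)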
106 changes: 52 additions & 54 deletions tests/lightning/test_megatron_parallel.py
@@ -1,6 +1,7 @@
from collections import defaultdict

import pytest
from megatron.core import parallel_state
from torch import nn

from nemo import lightning as nl
@@ -24,11 +25,10 @@ def forward(self, x):

        return DummyModule()

    # TODO (chcui): Uncomment this test when we merge mixed-precision
    # @pytest.fixture
    # def mock_precision_plugin(self, mocker):
    # """Fixture to create a mock precision plugin."""
    # return nl.MegatronMixedPrecision(precision="bf16-mixed")
    @pytest.fixture
    def mock_precision_plugin(self, mocker):
        """Fixture to create a mock precision plugin."""
        return nl.MegatronMixedPrecision(precision="bf16-mixed")

    @pytest.fixture
    def mock_callbacks(self, mocker):
@@ -64,55 +64,53 @@ def test_init_with_defaults(self, mocker, mock_pipeline):
        assert megatron_parallel.forward_step == mp.default_forward_step
        assert megatron_parallel.loss_reduction is None

    # TODO (chcui): Uncomment this test when we merge mixed-precision
    # def test_init_with_custom_parameters(
    # self,
    # mocker,
    # mock_pipeline,
    # mock_precision_plugin,
    # mock_callbacks,
    # mock_data_step,
    # mock_forward_step,
    # mock_loss_reduction
    # ):
    # """Test __init__ with custom parameters."""
    # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1)
    # mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False)
    #
    # megatron_parallel = mp.MegatronParallel(
    # pipeline=mock_pipeline,
    # precision_plugin=mock_precision_plugin,
    # callbacks=mock_callbacks,
    # data_step=mock_data_step,
    # forward_step=mock_forward_step,
    # loss_reduction=mock_loss_reduction
    # )
    #
    # assert megatron_parallel.pipeline == mock_pipeline
    # assert megatron_parallel.precision_plugin == mock_precision_plugin
    # assert megatron_parallel.callbacks == mock_callbacks
    # assert megatron_parallel.data_step == mock_data_step
    # assert megatron_parallel.forward_step == mock_forward_step
    # assert megatron_parallel.loss_reduction == mock_loss_reduction

    # TODO: Comment-out this test when we merge nemo.io
    # def test_init_with_virtual_pipeline(self, mocker, mock_pipeline):
    # """Test __init__ with virtual pipeline model parallel world size."""
    # mocker.patch('torch.distributed.get_rank', return_value=1)
    # mocker.patch('megatron.core.parallel_state.get_tensor_model_parallel_group', return_value=1)
    # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_group', return_value=1)
    # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=2)
    # mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=True)
    # mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_world_size')
    # mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_rank')
    # mocker.patch('nemo_ext.lightning._strategy_lib.init_lightning_module', return_value=mock_pipeline)

    # megatron_parallel = mp.MegatronParallel(mock_pipeline, vp_size=2)

    # assert len(megatron_parallel.pipeline) == 2
    # assert all(isinstance(mod, nn.Module) for mod in megatron_parallel.pipeline)
    # megatron.core.parallel_state.set_virtual_pipeline_model_parallel_world_size.assert_called_once_with(2)
    # assert megatron.core.parallel_state.set_virtual_pipeline_model_parallel_rank.call_count == 1
    def test_init_with_custom_parameters(
        self,
        mocker,
        mock_pipeline,
        mock_precision_plugin,
        mock_callbacks,
        mock_data_step,
        mock_forward_step,
        mock_loss_reduction,
    ):
        """Test __init__ with custom parameters."""
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1)
        mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False)

        megatron_parallel = mp.MegatronParallel(
            pipeline=mock_pipeline,
            precision_plugin=mock_precision_plugin,
            callbacks=mock_callbacks,
            data_step=mock_data_step,
            forward_step=mock_forward_step,
            loss_reduction=mock_loss_reduction,
        )

        assert megatron_parallel.pipeline == mock_pipeline
        assert megatron_parallel.precision_plugin == mock_precision_plugin
        assert megatron_parallel.callbacks == mock_callbacks
        assert megatron_parallel.data_step == mock_data_step
        assert megatron_parallel.forward_step == mock_forward_step
        assert megatron_parallel.loss_reduction == mock_loss_reduction

    def test_init_with_virtual_pipeline(self, mocker, mock_pipeline):
        """Test __init__ with virtual pipeline model parallel world size."""
        mocker.patch('torch.distributed.get_rank', return_value=1)
        mocker.patch('megatron.core.parallel_state.get_tensor_model_parallel_group', return_value=1)
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_group', return_value=1)
        mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=2)
        mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=True)
        mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_world_size')
        mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_rank')
        mocker.patch('nemo.io.reinit', return_value=mock_pipeline)

        megatron_parallel = mp.MegatronParallel(mock_pipeline, vp_size=2, cpu=True)

        assert len(megatron_parallel.pipeline) == 2
        assert all(isinstance(mod, nn.Module) for mod in megatron_parallel.pipeline)
        parallel_state.set_virtual_pipeline_model_parallel_world_size.assert_called_once_with(2)
        assert parallel_state.set_virtual_pipeline_model_parallel_rank.call_count == 1


class TestCallbackConnector:
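The re-enabled tests above lean on pytest-mock's mocker fixture to stub out megatron.core.parallel_state and nemo.io and then assert on call counts. A minimal, self-contained illustration of that patch-and-assert pattern, using an arbitrary torch function as the stand-in target (requires pytest and pytest-mock; not NeMo code):

# Illustrative only: the patched function below is a stand-in for the
# parallel_state calls patched in the tests above.
import torch


def rank_label() -> str:
    # Would normally require an initialized process group.
    return f"rank-{torch.distributed.get_rank()}"


def test_rank_label_uses_patched_rank(mocker):
    # Replace the real collective call with a stub so the test runs anywhere.
    mocked = mocker.patch("torch.distributed.get_rank", return_value=3)

    assert rank_label() == "rank-3"
    mocked.assert_called_once()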