TransformerEngine fallback compute dtype (#19082)
Co-authored-by: Adrian Wälchli <[email protected]>
carmocca and awaelchli authored Dec 14, 2023
1 parent d8b6bbd commit 97469c6
Showing 10 changed files with 95 additions and 29 deletions.
2 changes: 1 addition & 1 deletion docs/source-fabric/fundamentals/precision.rst
@@ -158,7 +158,7 @@ the model and inputs can be kept in true full or half precision.
from lightning.fabric.plugins import TransformerEnginePrecision
recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
precision = TransformerEnginePrecision(dtype=torch.bfloat16, recipe=recipe)
precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)
fabric = Fabric(plugins=precision)
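A minimal usage sketch of the updated Fabric API shown above (recipe values taken from the docs snippet; `fallback_compute_dtype=torch.float32` is an illustrative choice, since it defaults to `weights_dtype`):

import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins import TransformerEnginePrecision

# `weights_dtype` is now a required keyword-only argument; `fallback_compute_dtype`
# sets the dtype used for operations that do not support fp8 (illustrative value below)
recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
precision = TransformerEnginePrecision(
    weights_dtype=torch.bfloat16,
    recipe=recipe,
    fallback_compute_dtype=torch.float32,
)
fabric = Fabric(plugins=precision)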
2 changes: 1 addition & 1 deletion docs/source-pytorch/common/precision_intermediate.rst
@@ -147,7 +147,7 @@ the model and inputs can be kept in true full or half precision.
from lightning.trainer.plugins import TransformerEnginePrecision
recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
precision = TransformerEnginePrecision(dtype=torch.bfloat16, recipe=recipe)
precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)
trainer = Trainer(plugins=precision)
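The corresponding Trainer-side sketch under the same assumptions (the `lightning.pytorch.plugins` import path is assumed here):

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.plugins import TransformerEnginePrecision  # assumed import path

recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)
trainer = Trainer(plugins=precision)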
9 changes: 9 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -15,11 +15,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `lightning.fabric.utilities.AttributeDict` for convenient dict-attribute access to represent state in script ([#18943](https://github.com/Lightning-AI/lightning/pull/18943))


- Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))


### Changed

- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))


- Changed the `TransformerEnginePrecision(dtype=)` argument to `weights_dtype` and made it required ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))


### Deprecated

-
@@ -38,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))


- Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))



## [2.1.2] - 2023-11-15

4 changes: 2 additions & 2 deletions src/lightning/fabric/connector.py
@@ -469,9 +469,9 @@ def _check_and_init_precision(self) -> Precision:
if self._precision_input == "64-true":
return DoublePrecision()
if self._precision_input == "transformer-engine":
return TransformerEnginePrecision(dtype=torch.bfloat16)
return TransformerEnginePrecision(weights_dtype=torch.bfloat16)
if self._precision_input == "transformer-engine-float16":
return TransformerEnginePrecision(dtype=torch.float16)
return TransformerEnginePrecision(weights_dtype=torch.float16)

if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
rank_zero_warn(
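For illustration, the string shortcuts handled in this hunk are equivalent to constructing the plugin yourself (a sketch, assuming transformer_engine is installed):

from lightning.fabric import Fabric

fabric_bf16 = Fabric(precision="transformer-engine")          # -> TransformerEnginePrecision(weights_dtype=torch.bfloat16)
fabric_fp16 = Fabric(precision="transformer-engine-float16")  # -> TransformerEnginePrecision(weights_dtype=torch.float16)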
35 changes: 24 additions & 11 deletions src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -26,7 +26,7 @@
_convert_fp_tensor,
_DtypeContextManager,
)
from lightning.fabric.utilities.rank_zero import rank_zero_warn
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn

if TYPE_CHECKING:
from transformer_engine.common.recipe import DelayedScaling
@@ -42,13 +42,15 @@ class TransformerEnginePrecision(Precision):
.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
Args:
dtype: The weights dtype to use.
weights_dtype: The weights dtype to use.
recipe: Recipe for the DelayedScaling
`configuration <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling>`__.
In dict format or the dataclass format.
replace_layers: Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their Transformer
Engine alternatives. Note that they don't subclass the torch equivalents so checks like
``isinstance(l, torch.nn.Linear)`` will not pass.
fallback_compute_dtype: The compute dtype to use for operations that don't support fp8 autocast. Defaults to the
same as ``weights_dtype``.
.. note::
@@ -62,9 +64,11 @@ class TransformerEnginePrecision(Precision):

def __init__(
self,
dtype: Optional[torch.dtype] = None,
*,
weights_dtype: torch.dtype,
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
replace_layers: Optional[bool] = None,
fallback_compute_dtype: Optional[torch.dtype] = None,
) -> None:
if not _TRANSFORMER_ENGINE_AVAILABLE:
raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
@@ -80,21 +84,27 @@ def __init__(
recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
recipe = DelayedScaling(**recipe)

if dtype is None:
dtype = torch.get_default_dtype()
self.dtype = dtype
self.weights_dtype = weights_dtype
self.recipe = recipe
self.replace_layers = replace_layers
self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype

def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
# avoid converting if any is found. assume the user took care of it
if self.replace_layers and not any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
if self.replace_layers is True:
# info level because this is expected with `init_module`
rank_zero_info(
"`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
" TransformerEngine layers. Skipping"
)
elif self.replace_layers in (None, True):
_convert_layers(module)
module = module.to(dtype=self.dtype)
module = module.to(dtype=self.weights_dtype)
return module

def tensor_init_context(self) -> ContextManager:
return _DtypeContextManager(self.dtype)
return _DtypeContextManager(self.weights_dtype)

def module_init_context(self) -> ContextManager:
dtype_ctx = self.tensor_init_context()
@@ -113,17 +123,20 @@ def module_init_context(self) -> ContextManager:
return stack

def forward_context(self) -> ContextManager:
dtype_ctx = _DtypeContextManager(self.dtype)
dtype_ctx = _DtypeContextManager(self.weights_dtype)
fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
import transformer_engine.pytorch as te

autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
stack = ExitStack()
stack.enter_context(dtype_ctx)
# enable an outer fallback autocast for operations that do not support fp8
stack.enter_context(fallback_autocast_ctx)
stack.enter_context(autocast_ctx)
return stack

def convert_input(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)

def convert_output(self, data: Any) -> Any:
return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
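To make the new context layering concrete, here is a standalone sketch of the nesting that forward_context() builds: the weights dtype becomes the default dtype, a fallback autocast covers operations without fp8 support, and TransformerEngine's fp8 autocast wraps the fp8 execution itself. It assumes transformer_engine is installed and a CUDA device is available; `_default_dtype` is a hypothetical stand-in for Lightning's internal `_DtypeContextManager`.

from contextlib import ExitStack, contextmanager

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling


@contextmanager
def _default_dtype(dtype: torch.dtype):
    # hypothetical stand-in for _DtypeContextManager: temporarily switch the default dtype
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


def fp8_forward_context(weights_dtype: torch.dtype, fallback_compute_dtype: torch.dtype) -> ExitStack:
    """Recreate the nesting used during forward: default dtype -> fallback autocast -> fp8 autocast."""
    stack = ExitStack()
    stack.enter_context(_default_dtype(weights_dtype))
    # outer fallback autocast for operations that do not support fp8
    stack.enter_context(torch.autocast(device_type="cuda", dtype=fallback_compute_dtype))
    stack.enter_context(te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()))
    return stack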
9 changes: 9 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The Trainer now restores the training mode set through `.train()` or `.eval()` on a submodule-level when switching from validation to training ([#18951](https://github.com/Lightning-AI/lightning/pull/18951))


- Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))


### Changed

- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
@@ -32,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `LightningModule.load_from_checkpoint()` function now calls `.configure_model()` on the model if it is overridden, to ensure all layers can be loaded from the checkpoint ([#19036](https://github.com/Lightning-AI/lightning/pull/19036))


- Changed the `TransformerEnginePrecision(dtype=)` argument to `weights_dtype` and made it required ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))


### Deprecated

- Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))
@@ -65,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))


- Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))


## [2.1.2] - 2023-11-15

### Fixed
@@ -549,9 +549,9 @@ def _check_and_init_precision(self) -> Precision:
if self._precision_flag == "64-true":
return DoublePrecision()
if self._precision_flag == "transformer-engine":
return TransformerEnginePrecision(dtype=torch.bfloat16)
return TransformerEnginePrecision(weights_dtype=torch.bfloat16)
if self._precision_flag == "transformer-engine-float16":
return TransformerEnginePrecision(dtype=torch.float16)
return TransformerEnginePrecision(weights_dtype=torch.float16)

if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu":
rank_zero_warn(
10 changes: 5 additions & 5 deletions tests/tests_fabric/plugins/precision/test_transformer_engine.py
@@ -35,20 +35,20 @@ def test_transformer_engine_plugin(monkeypatch):

connector = _Connector(precision="transformer-engine")
assert isinstance(connector.precision, TransformerEnginePrecision)
assert connector.precision.dtype is torch.bfloat16
assert connector.precision.weights_dtype is torch.bfloat16
connector = _Connector(precision="transformer-engine-float16")
assert connector.precision.dtype is torch.float16
assert connector.precision.weights_dtype is torch.float16

recipe_mock.reset_mock()
precision = TransformerEnginePrecision()
precision = TransformerEnginePrecision(weights_dtype=torch.float32)
connector = _Connector(plugins=precision)
assert connector.precision is precision
assert precision.dtype == torch.float32
assert precision.weights_dtype == torch.float32
recipe_mock.DelayedScaling.assert_called_once_with()

recipe_mock.reset_mock()
recipe = {"foo": 0, "fp8_format": "HYBRID"}
precision = TransformerEnginePrecision(dtype=torch.float16, recipe=recipe)
precision = TransformerEnginePrecision(weights_dtype=torch.float16, recipe=recipe)
connector = _Connector(plugins=precision)
assert connector.precision is precision
recipe_mock.DelayedScaling.assert_called_once_with(foo=0, fp8_format=recipe_mock.Format.HYBRID)
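For reference, a sketch of the recipe normalization that the last assertion exercises: a dict recipe has its "fp8_format" string looked up on the Format enum before DelayedScaling is built (assumes transformer_engine is installed; the dict keys below are standard DelayedScaling arguments rather than the mocked ones used in the test):

from transformer_engine.common.recipe import DelayedScaling, Format

recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
recipe = dict(recipe)  # copy so the caller's dict is left untouched
recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])  # "HYBRID" -> Format.HYBRID
delayed_scaling = DelayedScaling(**recipe)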
@@ -125,7 +125,7 @@ def test_transformer_engine_precision_plugin(monkeypatch):
from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin

with pytest.deprecated_call(match=r"The `TransformerEnginePrecisionPlugin` is deprecated"):
TransformerEnginePrecisionPlugin()
TransformerEnginePrecisionPlugin(weights_dtype=torch.float32)


def test_xla_precision_plugin(xla_available):
47 changes: 41 additions & 6 deletions tests/tests_pytorch/plugins/precision/test_transformer_engine.py
Expand Up @@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License
import sys
from unittest.mock import Mock
from contextlib import nullcontext
from unittest.mock import ANY, Mock

import lightning.fabric
import pytest
import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.plugins import TransformerEnginePrecision
from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector


def test_transformer_engine_precision_plugin(monkeypatch):
import lightning.fabric # avoid breakage with standalone package

module = lightning.fabric.plugins.precision.transformer_engine
if module._TRANSFORMER_ENGINE_AVAILABLE:
pytest.skip("Assumes transformer_engine is unavailable")
@@ -32,10 +33,44 @@ def test_transformer_engine_precision_plugin(monkeypatch):

connector = _AcceleratorConnector(precision="transformer-engine")
assert isinstance(connector.precision_plugin, TransformerEnginePrecision)
assert connector.precision_plugin.dtype is torch.bfloat16
assert connector.precision_plugin.weights_dtype is torch.bfloat16
connector = _AcceleratorConnector(precision="transformer-engine-float16")
assert connector.precision_plugin.dtype is torch.float16
assert connector.precision_plugin.weights_dtype is torch.float16

precision = TransformerEnginePrecision()
precision = TransformerEnginePrecision(weights_dtype=torch.float32)
connector = _AcceleratorConnector(plugins=precision)
assert connector.precision_plugin is precision


def test_configure_model(monkeypatch):
module = lightning.fabric.plugins.precision.transformer_engine
if module._TRANSFORMER_ENGINE_AVAILABLE:
pytest.skip("Assumes transformer_engine is unavailable")
monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
te_mock = Mock()
te_mock.pytorch.fp8_autocast.return_value = nullcontext()

class ModuleMock(torch.nn.Linear):
def __init__(self, in_features, out_features, bias=True, *_, **__):
super().__init__(in_features, out_features, bias)

te_mock.pytorch.Linear = ModuleMock
monkeypatch.setitem(sys.modules, "transformer_engine", te_mock)
monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", te_mock)
monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", te_mock)

class MyModel(LightningModule):
def configure_model(self):
self.l = torch.nn.Linear(8, 16)
assert self.l.weight.dtype == torch.float16

def test_step(self, *_):
...

model = MyModel()
trainer = Trainer(barebones=True, precision="transformer-engine-float16")
trainer.test(model, [0])
te_mock.pytorch.fp8_autocast.assert_called_once_with(enabled=True, fp8_recipe=ANY)
# TODO: invert condition once this gets fixed
assert not isinstance(model.l, ModuleMock)
assert model.l.weight.dtype == torch.float16
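Finally, a hedged end-to-end sketch of the flow this test covers: with precision="transformer-engine-float16", layers created inside configure_model() are initialized directly in float16 and, by default, replaced with their TransformerEngine counterparts. Running it for real requires transformer_engine and a CUDA GPU; the dataset and hyperparameters below are placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer


class FP8Model(LightningModule):
    def configure_model(self):
        # built under the plugin's init context, so the weights come out as float16
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dataset = TensorDataset(torch.randn(8, 32))
trainer = Trainer(precision="transformer-engine-float16", accelerator="cuda", devices=1, max_steps=1)
trainer.fit(FP8Model(), DataLoader(dataset, batch_size=4))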
