diff --git a/docs/source-fabric/fundamentals/precision.rst b/docs/source-fabric/fundamentals/precision.rst
index 85c1f311ac22a..49ff2c30a9e2c 100644
--- a/docs/source-fabric/fundamentals/precision.rst
+++ b/docs/source-fabric/fundamentals/precision.rst
@@ -158,7 +158,7 @@ the model and inputs can be kept in true full or half precision.
     from lightning.fabric.plugins import TransformerEnginePrecision

     recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
-    precision = TransformerEnginePrecision(dtype=torch.bfloat16, recipe=recipe)
+    precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)

     fabric = Fabric(plugins=precision)
diff --git a/docs/source-pytorch/common/precision_intermediate.rst b/docs/source-pytorch/common/precision_intermediate.rst
index 18fb8a509a3dd..eff5805497c2d 100644
--- a/docs/source-pytorch/common/precision_intermediate.rst
+++ b/docs/source-pytorch/common/precision_intermediate.rst
@@ -147,7 +147,7 @@ the model and inputs can be kept in true full or half precision.
     from lightning.trainer.plugins import TransformerEnginePrecision

     recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
-    precision = TransformerEnginePrecision(dtype=torch.bfloat16, recipe=recipe)
+    precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)

     trainer = Trainer(plugins=precision)
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 07b0cdab58660..4bfcac952d1eb 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -15,11 +15,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `lightning.fabric.utilities.AttributeDict` for convenient dict-attribute access to represent state in script ([#18943](https://github.com/Lightning-AI/lightning/pull/18943))

+- Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ### Changed

 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))

+- Changed the `TransformerEnginePrecision(dtype=)` argument to `weights_dtype` and made it required ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ### Deprecated

 -

@@ -38,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))

+- Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ## [2.1.2] - 2023-11-15

diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py
index ddb91ba817755..f2e29379892d6 100644
--- a/src/lightning/fabric/connector.py
+++ b/src/lightning/fabric/connector.py
@@ -469,9 +469,9 @@ def _check_and_init_precision(self) -> Precision:
         if self._precision_input == "64-true":
             return DoublePrecision()
         if self._precision_input == "transformer-engine":
-            return TransformerEnginePrecision(dtype=torch.bfloat16)
+            return TransformerEnginePrecision(weights_dtype=torch.bfloat16)
         if self._precision_input == "transformer-engine-float16":
-            return TransformerEnginePrecision(dtype=torch.float16)
+            return TransformerEnginePrecision(weights_dtype=torch.float16)

         if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
             rank_zero_warn(
diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py
index 1aa4f66a27e8a..caeef6d32287c 100644
--- a/src/lightning/fabric/plugins/precision/transformer_engine.py
+++ b/src/lightning/fabric/plugins/precision/transformer_engine.py
@@ -26,7 +26,7 @@
     _convert_fp_tensor,
     _DtypeContextManager,
 )
-from lightning.fabric.utilities.rank_zero import rank_zero_warn
+from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn

 if TYPE_CHECKING:
     from transformer_engine.common.recipe import DelayedScaling
@@ -42,13 +42,15 @@ class TransformerEnginePrecision(Precision):
     .. warning:: This is an :ref:`experimental ` feature.

     Args:
-        dtype: The weights dtype to use.
+        weights_dtype: The weights dtype to use.
         recipe: Recipe for the DelayedScaling
             `configuration `__.
             In dict format or the dataclass format.
         replace_layers: Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their Transformer
             Engine alternatives. Note that they don't subclass the torch equivalents so checks like
             ``isinstance(l, torch.nn.Linear)`` will not pass.
+        fallback_compute_dtype: The compute dtype to use for operations that don't support fp8 autocast. Defaults to the
+            same as ``weights_dtype``.

     .. note::

@@ -62,9 +64,11 @@ class TransformerEnginePrecision(Precision):

     def __init__(
         self,
-        dtype: Optional[torch.dtype] = None,
+        *,
+        weights_dtype: torch.dtype,
         recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
         replace_layers: Optional[bool] = None,
+        fallback_compute_dtype: Optional[torch.dtype] = None,
     ) -> None:
         if not _TRANSFORMER_ENGINE_AVAILABLE:
             raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
@@ -80,21 +84,27 @@ def __init__(
             recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
             recipe = DelayedScaling(**recipe)

-        if dtype is None:
-            dtype = torch.get_default_dtype()
-        self.dtype = dtype
+        self.weights_dtype = weights_dtype
         self.recipe = recipe
         self.replace_layers = replace_layers
+        self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype

     def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
         # avoid converting if any is found. assume the user took care of it
-        if self.replace_layers and not any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
+        if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
+            if self.replace_layers is True:
+                # info level because this is expected with `init_module`
+                rank_zero_info(
+                    "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
+                    " TransformerEngine layers. Skipping"
+                )
+        elif self.replace_layers in (None, True):
             _convert_layers(module)
-        module = module.to(dtype=self.dtype)
+        module = module.to(dtype=self.weights_dtype)
         return module

     def tensor_init_context(self) -> ContextManager:
-        return _DtypeContextManager(self.dtype)
+        return _DtypeContextManager(self.weights_dtype)

     def module_init_context(self) -> ContextManager:
         dtype_ctx = self.tensor_init_context()
@@ -113,17 +123,20 @@ def module_init_context(self) -> ContextManager:
         return stack

     def forward_context(self) -> ContextManager:
-        dtype_ctx = _DtypeContextManager(self.dtype)
+        dtype_ctx = _DtypeContextManager(self.weights_dtype)
+        fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
         import transformer_engine.pytorch as te

         autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
         stack = ExitStack()
         stack.enter_context(dtype_ctx)
+        # enable an outer fallback autocast for operations that do not support fp8
+        stack.enter_context(fallback_autocast_ctx)
         stack.enter_context(autocast_ctx)
         return stack

     def convert_input(self, data: Any) -> Any:
-        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.dtype)
+        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)

     def convert_output(self, data: Any) -> Any:
         return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
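A minimal usage sketch of the updated Fabric plugin (not part of the diff; it assumes `transformer_engine` is installed and a CUDA device is available, and the `recipe` values mirror the docs example above):

```python
import torch

from lightning.fabric import Fabric
from lightning.fabric.plugins import TransformerEnginePrecision

# `weights_dtype` is now required and keyword-only; `fallback_compute_dtype` controls
# the compute dtype for operations that don't support fp8 and defaults to `weights_dtype`.
recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
precision = TransformerEnginePrecision(
    weights_dtype=torch.bfloat16,
    recipe=recipe,
    fallback_compute_dtype=torch.float16,  # run non-fp8 ops in fp16 instead of bf16
)
fabric = Fabric(accelerator="cuda", devices=1, plugins=precision)
fabric.launch()
```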
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index b60c9a2649de7..2ed92e125945c 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - The Trainer now restores the training mode set through `.train()` or `.eval()` on a submodule-level when switching from validation to training ([#18951](https://github.com/Lightning-AI/lightning/pull/18951))

+- Added `TransformerEnginePrecision(fallback_compute_dtype=)` to control the dtype of operations that don't support fp8 ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ### Changed

 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
@@ -32,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - The `LightningModule.load_from_checkpoint()` function now calls `.configure_model()` on the model if it is overridden, to ensure all layers can be loaded from the checkpoint ([#19036](https://github.com/Lightning-AI/lightning/pull/19036))

+- Changed the `TransformerEnginePrecision(dtype=)` argument to `weights_dtype` and made it required ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ### Deprecated

 - Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))
@@ -65,6 +71,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074))

+- Fixed issue where the `precision="transformer-engine"` argument would not replace layers by default ([#19082](https://github.com/Lightning-AI/lightning/pull/19082))
+
+
 ## [2.1.2] - 2023-11-15

 ### Fixed
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index f4e2a1ced7e93..a7679f3e1cc3c 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -549,9 +549,9 @@ def _check_and_init_precision(self) -> Precision:
         if self._precision_flag == "64-true":
             return DoublePrecision()
         if self._precision_flag == "transformer-engine":
-            return TransformerEnginePrecision(dtype=torch.bfloat16)
+            return TransformerEnginePrecision(weights_dtype=torch.bfloat16)
         if self._precision_flag == "transformer-engine-float16":
-            return TransformerEnginePrecision(dtype=torch.float16)
+            return TransformerEnginePrecision(weights_dtype=torch.float16)

         if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu":
             rank_zero_warn(
diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py
index b44df5233b453..63ae0bbb3a822 100644
--- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py
@@ -35,20 +35,20 @@ def test_transformer_engine_plugin(monkeypatch):

     connector = _Connector(precision="transformer-engine")
     assert isinstance(connector.precision, TransformerEnginePrecision)
-    assert connector.precision.dtype is torch.bfloat16
+    assert connector.precision.weights_dtype is torch.bfloat16
     connector = _Connector(precision="transformer-engine-float16")
-    assert connector.precision.dtype is torch.float16
+    assert connector.precision.weights_dtype is torch.float16

     recipe_mock.reset_mock()
-    precision = TransformerEnginePrecision()
+    precision = TransformerEnginePrecision(weights_dtype=torch.float32)
     connector = _Connector(plugins=precision)
     assert connector.precision is precision
-    assert precision.dtype == torch.float32
+    assert precision.weights_dtype == torch.float32
     recipe_mock.DelayedScaling.assert_called_once_with()

     recipe_mock.reset_mock()
     recipe = {"foo": 0, "fp8_format": "HYBRID"}
-    precision = TransformerEnginePrecision(dtype=torch.float16, recipe=recipe)
+    precision = TransformerEnginePrecision(weights_dtype=torch.float16, recipe=recipe)
     connector = _Connector(plugins=precision)
     assert connector.precision is precision
     recipe_mock.DelayedScaling.assert_called_once_with(foo=0, fp8_format=recipe_mock.Format.HYBRID)
diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
index b442fe93853af..96fdd88847f36 100644
--- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
+++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py
@@ -125,7 +125,7 @@ def test_transformer_engine_precision_plugin(monkeypatch):
     from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecisionPlugin

     with pytest.deprecated_call(match=r"The `TransformerEnginePrecisionPlugin` is deprecated"):
-        TransformerEnginePrecisionPlugin()
+        TransformerEnginePrecisionPlugin(weights_dtype=torch.float32)


 def test_xla_precision_plugin(xla_available):
diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
index cf9e79a53ad5e..bbec19be61b1f 100644
--- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
+++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py
@@ -12,17 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import sys
-from unittest.mock import Mock
+from contextlib import nullcontext
+from unittest.mock import ANY, Mock

+import lightning.fabric
 import pytest
 import torch
+from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.plugins import TransformerEnginePrecision
 from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector


 def test_transformer_engine_precision_plugin(monkeypatch):
-    import lightning.fabric  # avoid breakage with standalone package
-
     module = lightning.fabric.plugins.precision.transformer_engine
     if module._TRANSFORMER_ENGINE_AVAILABLE:
         pytest.skip("Assumes transformer_engine is unavailable")
@@ -32,10 +33,44 @@ def test_transformer_engine_precision_plugin(monkeypatch):

     connector = _AcceleratorConnector(precision="transformer-engine")
     assert isinstance(connector.precision_plugin, TransformerEnginePrecision)
-    assert connector.precision_plugin.dtype is torch.bfloat16
+    assert connector.precision_plugin.weights_dtype is torch.bfloat16
     connector = _AcceleratorConnector(precision="transformer-engine-float16")
-    assert connector.precision_plugin.dtype is torch.float16
+    assert connector.precision_plugin.weights_dtype is torch.float16

-    precision = TransformerEnginePrecision()
+    precision = TransformerEnginePrecision(weights_dtype=torch.float32)
     connector = _AcceleratorConnector(plugins=precision)
     assert connector.precision_plugin is precision
+
+
+def test_configure_model(monkeypatch):
+    module = lightning.fabric.plugins.precision.transformer_engine
+    if module._TRANSFORMER_ENGINE_AVAILABLE:
+        pytest.skip("Assumes transformer_engine is unavailable")
+    monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)
+    te_mock = Mock()
+    te_mock.pytorch.fp8_autocast.return_value = nullcontext()
+
+    class ModuleMock(torch.nn.Linear):
+        def __init__(self, in_features, out_features, bias=True, *_, **__):
+            super().__init__(in_features, out_features, bias)
+
+    te_mock.pytorch.Linear = ModuleMock
+    monkeypatch.setitem(sys.modules, "transformer_engine", te_mock)
+    monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", te_mock)
+    monkeypatch.setitem(sys.modules, "transformer_engine.common.recipe", te_mock)
+
+    class MyModel(LightningModule):
+        def configure_model(self):
+            self.l = torch.nn.Linear(8, 16)
+            assert self.l.weight.dtype == torch.float16
+
+        def test_step(self, *_):
+            ...
+
+    model = MyModel()
+    trainer = Trainer(barebones=True, precision="transformer-engine-float16")
+    trainer.test(model, [0])
+    te_mock.pytorch.fp8_autocast.assert_called_once_with(enabled=True, fp8_recipe=ANY)
+    # TODO: invert condition once this gets fixed
+    assert not isinstance(model.l, ModuleMock)
+    assert model.l.weight.dtype == torch.float16
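
Taken together, the connector and plugin changes above mean the string shorthands keep working while direct construction now requires the keyword argument. A short sketch of both call styles with the Trainer (not part of the diff; it assumes `transformer_engine` is installed on a CUDA machine):

```python
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.plugins import TransformerEnginePrecision

# The shorthand strings map onto the new keyword-only argument:
#   "transformer-engine"         -> TransformerEnginePrecision(weights_dtype=torch.bfloat16)
#   "transformer-engine-float16" -> TransformerEnginePrecision(weights_dtype=torch.float16)
trainer = Trainer(accelerator="cuda", devices=1, precision="transformer-engine")

# Passing the plugin directly: the old `dtype=` keyword is gone and `weights_dtype` is required.
precision = TransformerEnginePrecision(weights_dtype=torch.float16)
trainer = Trainer(accelerator="cuda", devices=1, plugins=precision)
```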