diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py
index 7b8d4a9330df8..3e9a7560d6472 100644
--- a/src/lightning_lite/connector.py
+++ b/src/lightning_lite/connector.py
@@ -16,6 +16,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+from typing_extensions import Literal
 
 from lightning_lite.accelerators import ACCELERATOR_REGISTRY
 from lightning_lite.accelerators.accelerator import Accelerator
@@ -58,6 +59,7 @@
 
 _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
+_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]
 
 
 class _Connector:
@@ -97,7 +99,7 @@ def __init__(
         strategy: Optional[Union[str, Strategy]] = None,
         devices: Optional[Union[List[int], str, int]] = None,
         num_nodes: int = 1,
-        precision: Union[int, str] = 32,
+        precision: _PRECISION_INPUT = 32,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
     ) -> None:
         # 1. Parsing flags
@@ -111,7 +113,7 @@ def __init__(
         # For devices: Assign gpus, ipus, etc. to the accelerator flag and devices flag
         self._strategy_flag: Optional[Union[Strategy, str]] = None
         self._accelerator_flag: Optional[Union[Accelerator, str]] = None
-        self._precision_flag: Optional[Union[int, str]] = None
+        self._precision_flag: Optional[_PRECISION_INPUT] = None
         self._precision_plugin_flag: Optional[Precision] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
         self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -154,7 +156,7 @@ def _check_config_and_set_final_flags(
         self,
         strategy: Optional[Union[str, Strategy]],
         accelerator: Optional[Union[str, Accelerator]],
-        precision: Union[int, str],
+        precision: _PRECISION_INPUT,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]],
     ) -> None:
         """This method checks:
diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py
index 414c63c841813..04b964e41c5a0 100644
--- a/src/lightning_lite/lite.py
+++ b/src/lightning_lite/lite.py
@@ -27,7 +27,7 @@
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
 
 from lightning_lite.accelerators.accelerator import Accelerator
-from lightning_lite.connector import _Connector, _PLUGIN_INPUT
+from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
 from lightning_lite.plugins import Precision
 from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy
 from lightning_lite.strategies.strategy import TBroadcast
@@ -74,7 +74,7 @@ def __init__(
         strategy: Optional[Union[str, Strategy]] = None,
         devices: Optional[Union[List[int], str, int]] = None,
         num_nodes: int = 1,
-        precision: Union[int, str] = 32,
+        precision: _PRECISION_INPUT = 32,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
     ) -> None:
         self._connector = _Connector(
diff --git a/src/lightning_lite/plugins/precision/deepspeed.py b/src/lightning_lite/plugins/precision/deepspeed.py
index 3ccf2d3ba09b4..3817bf3aa3f04 100644
--- a/src/lightning_lite/plugins/precision/deepspeed.py
+++ b/src/lightning_lite/plugins/precision/deepspeed.py
@@ -11,12 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING
 
+import torch
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
+from typing_extensions import Literal
 
 from lightning_lite.plugins.precision.precision import Precision
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.imports import _APEX_AVAILABLE
 from lightning_lite.utilities.types import Steppable
@@ -43,7 +46,7 @@ class DeepSpeedPrecision(Precision):
             If unsupported ``precision`` is provided.
     """
 
-    def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
+    def __init__(self, precision: Literal[16, 32, "bf16"], amp_type: str, amp_level: Optional[str] = None) -> None:
        if amp_type == AMPType.APEX:
             if not _APEX_AVAILABLE:
                 raise ImportError(
@@ -65,6 +68,11 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
         self.amp_type = amp_type
         self.amp_level = amp_level
 
+    def convert_input(self, data: Tensor) -> Tensor:
+        precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32}
+        dst_type = precision_to_type[self.precision]
+        return _convert_fp_tensor(data, dst_type)
+
     def backward(self, tensor: Tensor, model: "deepspeed.DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
         """Performs back-propagation using DeepSpeed's engine."""
         model.backward(tensor, *args, **kwargs)
diff --git a/src/lightning_lite/plugins/precision/double.py b/src/lightning_lite/plugins/precision/double.py
index 9bbb033b742a9..dd0aa73eee5f0 100644
--- a/src/lightning_lite/plugins/precision/double.py
+++ b/src/lightning_lite/plugins/precision/double.py
@@ -15,9 +15,11 @@
 from typing import Generator
 
 import torch
+from torch import Tensor
 from torch.nn import Module
 
 from lightning_lite.plugins.precision.precision import Precision
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 
 
 class DoublePrecision(Precision):
@@ -38,3 +40,6 @@ def forward_context(self) -> Generator[None, None, None]:
         torch.set_default_dtype(torch.float64)
         yield
         torch.set_default_dtype(default_dtype)
+
+    def convert_input(self, data: Tensor) -> Tensor:
+        return _convert_fp_tensor(data, torch.double)
diff --git a/src/lightning_lite/plugins/precision/native_amp.py b/src/lightning_lite/plugins/precision/native_amp.py
index 58cdffeb99cbc..34b4fb5591724 100644
--- a/src/lightning_lite/plugins/precision/native_amp.py
+++ b/src/lightning_lite/plugins/precision/native_amp.py
@@ -18,8 +18,10 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import LBFGS
+from typing_extensions import Literal
 
 from lightning_lite.plugins.precision.precision import Precision
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
 from lightning_lite.utilities.types import Steppable
@@ -39,7 +41,7 @@ class NativeMixedPrecision(Precision):
     """
 
     def __init__(
-        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
+        self, precision: Literal[16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
     ) -> None:
         super().__init__()
         if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
@@ -57,6 +59,11 @@ def forward_context(self) -> Generator[None, None, None]:
         with self._autocast_context_manager():
             yield
 
+    def convert_input(self, data: Tensor) -> Tensor:
+        precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16}
+        dst_type = precision_to_type[self.precision]
+        return _convert_fp_tensor(data, dst_type)
+
     def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
diff --git a/src/lightning_lite/plugins/precision/precision.py b/src/lightning_lite/plugins/precision/precision.py
index 07ce77b8897bf..6a3883a022edb 100644
--- a/src/lightning_lite/plugins/precision/precision.py
+++ b/src/lightning_lite/plugins/precision/precision.py
@@ -14,10 +14,12 @@
 import contextlib
 from typing import Any, Dict, Generator, Optional, Union
 
+import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.utilities.types import _PARAMETERS, Steppable
@@ -41,6 +43,13 @@ def forward_context(self) -> Generator[None, None, None]:
         """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
         yield
 
+    def convert_input(self, data: Tensor) -> Tensor:
+        """Convert model inputs (forward) to the floating point precision type of this plugin.
+
+        This is a no-op for tensors that are not of floating-point type or already have the desired type.
+        """
+        return _convert_fp_tensor(data, torch.float32)
+
     def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
         """Runs before precision plugin executes backward.
diff --git a/src/lightning_lite/plugins/precision/tpu_bf16.py b/src/lightning_lite/plugins/precision/tpu_bf16.py
index d388a9ae175ac..84824b35e4e0d 100644
--- a/src/lightning_lite/plugins/precision/tpu_bf16.py
+++ b/src/lightning_lite/plugins/precision/tpu_bf16.py
@@ -13,7 +13,11 @@
 # limitations under the License.
 import os
 
+import torch
+from torch import Tensor
+
 from lightning_lite.plugins.precision import TPUPrecision
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 
 
 class TPUBf16Precision(TPUPrecision):
@@ -25,5 +29,8 @@ def __init__(self) -> None:
         super().__init__()
         os.environ["XLA_USE_BF16"] = "1"
 
+    def convert_input(self, data: Tensor) -> Tensor:
+        return _convert_fp_tensor(data, torch.bfloat16)
+
     def teardown(self) -> None:
         os.environ.pop("XLA_USE_BF16", None)
diff --git a/src/lightning_lite/plugins/precision/utils.py b/src/lightning_lite/plugins/precision/utils.py
index f607755e5da7d..dc41a5202d817 100644
--- a/src/lightning_lite/plugins/precision/utils.py
+++ b/src/lightning_lite/plugins/precision/utils.py
@@ -15,16 +15,6 @@
 
 import torch
 
-from lightning_lite.utilities.enums import PrecisionType
-
-
-def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
-    if precision == PrecisionType.HALF:
-        return _convert_fp_tensor(tensor, torch.half)
-    if precision == PrecisionType.BFLOAT:
-        return _convert_fp_tensor(tensor, torch.bfloat16)
-    return tensor
-
 
 def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
     return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py
index facc9f5cbd3ef..da532c5a567fd 100644
--- a/src/lightning_lite/strategies/deepspeed.py
+++ b/src/lightning_lite/strategies/deepspeed.py
@@ -23,16 +23,13 @@
 import torch
 from lightning_utilities.core.imports import RequirementCache
 from lightning_utilities.core.rank_zero import rank_zero_only
-from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 
 from lightning_lite.accelerators import Accelerator, CUDAAccelerator
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning_lite.plugins.precision import Precision
-from lightning_lite.plugins.precision.utils import _fp_to_half
 from lightning_lite.strategies.ddp import DDPStrategy
-from lightning_lite.utilities.apply_func import apply_to_collection
 from lightning_lite.utilities.distributed import log
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.rank_zero import rank_zero_info
@@ -366,10 +363,6 @@ def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any])
         self.module_to_device(module)
         self._restore_zero_state(module, checkpoint)
 
-    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
-        batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision)
-        return super().batch_to_device(batch, device)
-
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")
diff --git a/src/lightning_lite/wrappers.py b/src/lightning_lite/wrappers.py
index a976edb09a10b..c13f9a450b054 100644
--- a/src/lightning_lite/wrappers.py
+++ b/src/lightning_lite/wrappers.py
@@ -97,23 +97,13 @@ def module(self) -> nn.Module:
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Casts all inputs to the right precision and handles autocast for operations in the module forward
         method."""
-        precision = self._precision_plugin.precision
-        precision_to_type = {
-            "bf16": torch.bfloat16,
-            16: torch.float16,
-            32: torch.float32,
-            64: torch.float64,
-        }
-        # TODO: let the precision plugin handle the conversion
-        args, kwargs = apply_to_collection(
-            [args, kwargs], dtype=Tensor, function=_convert_fp_tensor, dst_type=precision_to_type[precision]
-        )
+        args, kwargs = apply_to_collection([args, kwargs], function=self._precision_plugin.convert_input, dtype=Tensor)
 
         with self._precision_plugin.forward_context():
             output = self._forward_module(*args, **kwargs)
 
         output = apply_to_collection(
-            output, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.get_default_dtype()
+            output, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()
         )
         return output
diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py
index c34561b702a25..be6f2108249d0 100644
--- a/src/pytorch_lightning/lite/lite.py
+++ b/src/pytorch_lightning/lite/lite.py
@@ -18,6 +18,7 @@
 from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
 
 from lightning_lite.connector import _PLUGIN_INPUT as _LITE_PLUGIN_INPUT
+from lightning_lite.connector import _PRECISION_INPUT
 from lightning_lite.lite import LightningLite as _NewLightningLite
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
 from lightning_lite.plugins import DeepSpeedPrecision as LiteDeepSpeedPrecision
@@ -98,7 +99,7 @@ def __init__(
         strategy: Optional[Union[str, PLStrategy]] = None,
         devices: Optional[Union[List[int], str, int]] = None,
         num_nodes: int = 1,
-        precision: Union[int, str] = 32,
+        precision: _PRECISION_INPUT = 32,
         plugins: Optional[Union[_PL_PLUGIN_INPUT, List[_PL_PLUGIN_INPUT]]] = None,
         gpus: Optional[Union[List[int], str, int]] = None,
         tpu_cores: Optional[Union[List[int], str, int]] = None,
@@ -286,13 +287,17 @@ def _to_lite_precision_plugin(plugin: Optional[PLPrecisionPlugin]) -> LitePrecis
         return LitePrecision()
 
     if type(plugin) is PLNativeMixedPrecisionPlugin:
-        return LiteNativeMixedPrecision(precision=plugin.precision, device=plugin.device, scaler=plugin.scaler)
+        return LiteNativeMixedPrecision(
+            precision=plugin.precision, device=plugin.device, scaler=plugin.scaler  # type: ignore[arg-type]
+        )
 
     if type(plugin) is PLDoublePrecisionPlugin:
         return LiteDoublePrecision()
 
     if type(plugin) is PLDeepSpeedPrecisionPlugin:
-        return LiteDeepSpeedPrecision(precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level)
+        return LiteDeepSpeedPrecision(
+            precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level  # type: ignore[arg-type]
+        )
 
     if type(plugin) is PLTPUPrecisionPlugin:
         return LiteTPUPrecision()
diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py
index a3a6a7998e546..82f0aed9d1366 100644
--- a/src/pytorch_lightning/strategies/deepspeed.py
+++ b/src/pytorch_lightning/strategies/deepspeed.py
@@ -31,7 +31,6 @@
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import ClusterEnvironment
-from lightning_lite.plugins.precision.utils import _fp_to_half
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
@@ -41,6 +40,7 @@
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
+from pytorch_lightning.strategies.utils import _fp_to_half
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py
index a5daa13dc9de6..66c11b26c90ff 100644
--- a/src/pytorch_lightning/strategies/ipu.py
+++ b/src/pytorch_lightning/strategies/ipu.py
@@ -22,12 +22,12 @@
 
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
-from lightning_lite.plugins.precision.utils import _fp_to_half
 from lightning_lite.utilities.cloud_io import get_filesystem
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
+from pytorch_lightning.strategies.utils import _fp_to_half
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls
diff --git a/src/pytorch_lightning/strategies/utils.py b/src/pytorch_lightning/strategies/utils.py
index fa360d3770cd7..be602a3929665 100644
--- a/src/pytorch_lightning/strategies/utils.py
+++ b/src/pytorch_lightning/strategies/utils.py
@@ -15,7 +15,11 @@
 import os
 from inspect import getmembers, isclass
 
+import torch
+
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.strategies import _StrategyRegistry
+from lightning_lite.utilities.enums import PrecisionType
 from lightning_lite.utilities.registry import _is_register_method_overridden
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
@@ -34,3 +38,11 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
     for _, mod in getmembers(module, isclass):
         if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
             mod.register_strategies(registry)
+
+
+def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
+    if precision == PrecisionType.HALF:
+        return _convert_fp_tensor(tensor, torch.half)
+    if precision == PrecisionType.BFLOAT:
+        return _convert_fp_tensor(tensor, torch.bfloat16)
+    return tensor
diff --git a/tests/tests_lite/plugins/precision/test_native_amp_integration.py b/tests/tests_lite/plugins/precision/test_native_amp_integration.py
index f657236069342..94d8c399679cf 100644
--- a/tests/tests_lite/plugins/precision/test_native_amp_integration.py
+++ b/tests/tests_lite/plugins/precision/test_native_amp_integration.py
@@ -67,6 +67,6 @@ def after_backward(self, model):
     ],
 )
 def test_native_mixed_precision(accelerator, precision, expected_dtype):
-    lite = NativeMixedPrecisionBoringLite(accelerator=accelerator, precision=16)
+    lite = NativeMixedPrecisionBoringLite(accelerator=accelerator, precision=precision)
     lite.expected_dtype = expected_dtype
     lite.run()
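
For context on the pattern this diff introduces, here is a minimal standalone sketch (not Lightning source): each precision plugin now owns input conversion through a `convert_input` hook, and `_convert_fp_tensor` casts only floating-point tensors, so integer targets and masks pass through untouched. The `ToyHalfPrecision` class and the dict comprehension below are illustrative stand-ins; the actual wrapper applies the hook via `apply_to_collection([args, kwargs], function=self._precision_plugin.convert_input, dtype=Tensor)` as shown in `src/lightning_lite/wrappers.py` above.

    # Illustrative sketch only: ToyHalfPrecision is a stand-in, not a Lightning class.
    import torch
    from torch import Tensor

    def _convert_fp_tensor(tensor: Tensor, dst_type: torch.dtype) -> Tensor:
        # Cast only floating-point tensors; leave integer/bool tensors untouched.
        return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor

    class ToyHalfPrecision:
        def convert_input(self, data: Tensor) -> Tensor:
            return _convert_fp_tensor(data, torch.float16)

    plugin = ToyHalfPrecision()
    batch = {"pixels": torch.rand(2, 3), "labels": torch.tensor([0, 1])}
    converted = {k: plugin.convert_input(v) for k, v in batch.items()}
    assert converted["pixels"].dtype == torch.float16  # floats follow the plugin's dtype
    assert converted["labels"].dtype == torch.int64  # non-floating tensors are unchanged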