From de2281dfbb97c16eecb81a8b149c9b9e250e79eb Mon Sep 17 00:00:00 2001 From: Justus Schock Date: Thu, 16 Feb 2023 15:06:19 +0100 Subject: [PATCH] update src code --- .../pytorch/plugins/precision/amp.py | 12 ++-- .../pytorch/plugins/precision/deepspeed.py | 11 ++-- .../pytorch/plugins/precision/double.py | 2 +- .../pytorch/plugins/precision/fsdp.py | 8 +-- .../pytorch/plugins/precision/hpu.py | 10 ++-- .../pytorch/plugins/precision/ipu.py | 13 ++--- .../pytorch/plugins/precision/tpu_bf16.py | 2 +- src/lightning/pytorch/strategies/deepspeed.py | 12 ++-- src/lightning/pytorch/strategies/fsdp.py | 4 +- .../connectors/accelerator_connector.py | 55 +++++++++++-------- src/lightning/pytorch/trainer/trainer.py | 7 ++- 11 files changed, 70 insertions(+), 66 deletions(-) diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index ce984070ae7a5..bddcf2b480a75 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -34,15 +34,15 @@ class MixedPrecisionPlugin(PrecisionPlugin): """ def __init__( - self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None + self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None ) -> None: - self.precision = cast(Literal["16", "bf16"], str(precision)) - if scaler is None and self.precision == "16": + self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision)) + if scaler is None and self.precision == "16-mixed": with _patch_cuda_is_available(): # if possible, we defer CUDA initialization to support strategies that will attempt forks scaler = torch.cuda.amp.GradScaler() - if scaler is not None and self.precision == "bf16": - raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.") + if scaler is not None and self.precision == "bf16-mixed": + raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.") self.device = device self.scaler = scaler @@ -97,7 +97,7 @@ def clip_gradients( def autocast_context_manager(self) -> torch.autocast: # the dtype could be automatically inferred but we need to manually set it due to a bug upstream # https://github.com/pytorch/pytorch/issues/67233 - return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half) + return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half) @contextmanager def forward_context(self) -> Generator[None, None, None]: diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 20f5748aa444e..99f8d06173fa8 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -31,9 +31,8 @@ warning_cache = WarningCache() -_PRECISION_INPUT_INT = Literal[32, 16] -_PRECISION_INPUT_STR = Literal["32", "16", "bf16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] +_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"] + class DeepSpeedPrecisionPlugin(PrecisionPlugin): @@ -46,14 +45,14 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin): If unsupported ``precision`` is provided. 
""" - def __init__(self, precision: Literal["32", 32, "16", 16, "bf16"]) -> None: - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + def __init__(self, precision: Literal["32-true", "16-mixed", "bf16-mixed"]) -> None: + supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." ) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) + self.precision = cast(_PRECISION_INPUT, str(precision)) def backward( # type: ignore[override] self, diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index e008097046637..77fa9c4171a2b 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -72,7 +72,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class DoublePrecisionPlugin(PrecisionPlugin): """Plugin for training with double (``torch.float64``) precision.""" - precision: Literal["64"] = "64" + precision: Literal["64-true"] = "64-true" def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 7e1d6a5250294..1561bd693f037 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -31,12 +31,12 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin): """AMP for Fully Sharded Data Parallel (FSDP) Training.""" def __init__( - self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None + self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[ShardedGradScaler] = None ) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.") super().__init__( - precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None) + precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16-mixed" else None) ) def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: @@ -52,9 +52,9 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: @property def mixed_precision_config(self) -> Optional[MixedPrecision]: assert MixedPrecision is not None - if self.precision == "16": + if self.precision == "16-mixed": dtype = torch.float16 - elif self.precision == "bf16": + elif self.precision == "bf16-mixed": dtype = torch.bfloat16 else: raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.") diff --git a/src/lightning/pytorch/plugins/precision/hpu.py b/src/lightning/pytorch/plugins/precision/hpu.py index 8d805deae1da8..65863e9f77d22 100644 --- a/src/lightning/pytorch/plugins/precision/hpu.py +++ b/src/lightning/pytorch/plugins/precision/hpu.py @@ -22,9 +22,7 @@ if _HPU_AVAILABLE: from habana_frameworks.torch.hpex import hmp -_PRECISION_INPUT_INT = Literal[32, 16] -_PRECISION_INPUT_STR = Literal["32", "16", "bf16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] +_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"] class HPUPrecisionPlugin(PrecisionPlugin): @@ -48,14 +46,14 @@ def __init__( ) -> None: if not _HPU_AVAILABLE: raise MisconfigurationException("HPU precision plugin 
requires HPU devices.") - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." ) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) - if self.precision in ("16", "bf16"): + self.precision = cast(_PRECISION_INPUT, str(precision)) + if self.precision in ("16-mixed", "bf16-mixed"): hmp.convert( opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose ) diff --git a/src/lightning/pytorch/plugins/precision/ipu.py b/src/lightning/pytorch/plugins/precision/ipu.py index f82bc07ac2119..95737fb996d49 100644 --- a/src/lightning/pytorch/plugins/precision/ipu.py +++ b/src/lightning/pytorch/plugins/precision/ipu.py @@ -27,27 +27,24 @@ warning_cache = WarningCache() -_PRECISION_INPUT_INT = Literal[32, 16] -_PRECISION_INPUT_STR = Literal["32", "16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] - +_PRECISION_INPUT = Literal["32-true", "16-mixed"] class IPUPrecisionPlugin(PrecisionPlugin): """Precision plugin for IPU integration. Raises: ValueError: - If the precision is neither 16 nor 32. + If the precision is neither 16-mixed nor 32-true. """ - def __init__(self, precision: Literal["32", 32, "16", 16]) -> None: - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + def __init__(self, precision: Literal["32-true", "16-mixed"]) -> None: + supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." ) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) + self.precision = cast(_PRECISION_INPUT, str(precision)) def backward( # type: ignore[override] self, diff --git a/src/lightning/pytorch/plugins/precision/tpu_bf16.py b/src/lightning/pytorch/plugins/precision/tpu_bf16.py index 814af173d0464..0ff320d1ea807 100644 --- a/src/lightning/pytorch/plugins/precision/tpu_bf16.py +++ b/src/lightning/pytorch/plugins/precision/tpu_bf16.py @@ -23,7 +23,7 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin): """Plugin that enables bfloats on TPUs.""" - precision: Literal["bf16"] = "bf16" + precision: Literal["bf16"] = "bf16-mixed" def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 4b7a806cbf95e..7778cab07c47f 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -127,8 +127,8 @@ def __init__( Arguments: - zero_optimization: Enable ZeRO optimization. This is compatible with either `precision=16` or - `precision="bf16"`. + zero_optimization: Enable ZeRO optimization. This is compatible with either `precision="16-mixed"` or + `precision="bf16-mixed"`. stage: Different stages of the ZeRO Optimizer. 
0 is disabled, 1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning, @@ -505,9 +505,9 @@ def model_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: assert self._config_initialized - if self.precision_plugin.precision == "16": + if self.precision_plugin.precision == "16-mixed": dtype = torch.float16 - elif self.precision_plugin.precision == "bf16": + elif self.precision_plugin.precision == "bf16-mixed": dtype = torch.bfloat16 else: dtype = torch.float32 @@ -641,7 +641,7 @@ def _auto_select_batch_size(self) -> int: def _format_precision_config(self) -> None: assert isinstance(self.config, dict) - if self.precision_plugin.precision == "16": + if self.precision_plugin.precision == "16-mixed": if "fp16" not in self.config: # FP16 is a DeepSpeed standalone AMP implementation rank_zero_info("Enabling DeepSpeed FP16.") @@ -653,7 +653,7 @@ def _format_precision_config(self) -> None: "hysteresis": self.hysteresis, "min_loss_scale": self.min_loss_scale, } - elif "bf16" not in self.config and self.precision_plugin.precision == "bf16": + elif "bf16" not in self.config and self.precision_plugin.precision == "bf16-mixed": rank_zero_info("Enabling DeepSpeed BF16.") self.config["bf16"] = {"enabled": True} diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 0ac1709ad3680..f58dfc1db90f8 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -99,8 +99,8 @@ class FSDPStrategy(ParallelStrategy): algorithms to help backward communication and computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. mixed_precision: - Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` - or BF16 if ``precision=bf16`` unless a config is passed in. + Mixed Precision config. By default, Lightning will enable FP16 if ``precision="16-mixed"`` + or BF16 if ``precision="bf16-mixed"`` unless a config is passed in. This is only available in PyTorch 1.12 and later. activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation checkpointing. This is typically your transformer block (including attention + feed-forward). 
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index d803ec58c25e5..4185efe87407c 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -71,13 +71,11 @@
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _LIGHTNING_COLOSSALAI_AVAILABLE
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from lightning.fabric.connector import _PRECISION_INPUT, _PRECISION_INPUT_STR_LEGACY, _PRECISION_INPUT_STR_LEGACY_CONVERSION, _PRECISION_INPUT_STR, _PRECISION_INPUT_INT
 
 log = logging.getLogger(__name__)
 
 _LITERAL_WARN = Literal["warn"]
-_PRECISION_INPUT_INT = Literal[64, 32, 16]
-_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
-_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
 
 
 class AcceleratorConnector:
@@ -88,7 +86,7 @@ def __init__(
         accelerator: Optional[Union[str, Accelerator]] = None,
         strategy: Optional[Union[str, Strategy]] = None,
         plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
-        precision: _PRECISION_INPUT = 32,
+        precision: _PRECISION_INPUT = "32-true",
         sync_batchnorm: bool = False,
         benchmark: Optional[bool] = None,
         replace_sampler_ddp: bool = True,
@@ -136,7 +134,7 @@ def __init__(
         # Set each valid flag to `self._x_flag` after validation
         self._strategy_flag: Optional[Union[Strategy, str]] = None
         self._accelerator_flag: Optional[Union[Accelerator, str]] = None
-        self._precision_flag: _PRECISION_INPUT_STR = "32"
+        self._precision_flag: _PRECISION_INPUT_STR = "32-true"
         self._precision_plugin_flag: Optional[PrecisionPlugin] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
         self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -243,12 +241,23 @@ def _check_config_and_set_final_flags(
 
             self._accelerator_flag = accelerator
 
-        supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
+        supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_LEGACY)
         if precision not in supported_precision:
             raise MisconfigurationException(
                 f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}"
             )
-        self._precision_flag = cast(_PRECISION_INPUT_STR, str(precision))
+
+        precision = str(precision)  # convert int flags to str here to enable the legacy-conversion below
+
+        if precision in get_args(_PRECISION_INPUT_STR_LEGACY):
+            if precision[:2] not in ("32", "64"):
+                rank_zero_warn(
+                    f"{precision} is supported for historical reasons but its usage is discouraged. "
+                    f"Please set your precision to {_PRECISION_INPUT_STR_LEGACY_CONVERSION[precision]} instead!"
+ ) + precision = _PRECISION_INPUT_STR_LEGACY_CONVERSION[precision] + + self._precision_flag = cast(_PRECISION_INPUT_STR, precision) if plugins: plugins_flags_types: Dict[str, int] = Counter() @@ -518,13 +527,13 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if isinstance(self.accelerator, HPUAccelerator): return HPUPrecisionPlugin(self._precision_flag) # type: ignore if isinstance(self.accelerator, TPUAccelerator): - if self._precision_flag == "32": + if self._precision_flag == "32-true": return TPUPrecisionPlugin() - elif self._precision_flag in ("16", "bf16"): - if self._precision_flag == "16": + elif self._precision_flag in ("16-mixed", "bf16-mixed"): + if self._precision_flag == "16-mixed": rank_zero_warn( - "You passed `Trainer(accelerator='tpu', precision=16)` but AMP" - " is not supported with TPUs. Using `precision='bf16'` instead." + "You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16" + " is not supported on TPUs. Using `precision='bf16-mixed'` instead." ) return TPUBf16PrecisionPlugin() @@ -537,21 +546,21 @@ def _check_and_init_precision(self) -> PrecisionPlugin: if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecisionPlugin(self._precision_flag) - if self._precision_flag == "32": + if self._precision_flag == "32-true": return PrecisionPlugin() - if self._precision_flag == "64": + if self._precision_flag == "64-true": return DoublePrecisionPlugin() - if self._precision_flag == "16" and self._accelerator_flag == "cpu": + if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu": rank_zero_warn( - "You passed `Trainer(accelerator='cpu', precision=16)` but AMP is not supported on CPU." - " Using `precision='bf16'` instead." + "You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on " + "CPU. Using `precision='bf16-mixed'` instead." ) - self._precision_flag = "bf16" + self._precision_flag = "bf16-mixed" - if self._precision_flag in ("16", "bf16"): + if self._precision_flag in ("16-mixed", "bf16-mixed"): rank_zero_info( - f"Using {'16bit' if self._precision_flag == 16 else 'bfloat16'} Automatic Mixed Precision (AMP)" + f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" @@ -564,9 +573,9 @@ def _check_and_init_precision(self) -> PrecisionPlugin: def _validate_precision_choice(self) -> None: """Validate the combination of choices for precision, AMP type, and accelerator.""" if isinstance(self.accelerator, TPUAccelerator): - if self._precision_flag == "64": + if self._precision_flag == "64-true": raise MisconfigurationException( - "`Trainer(accelerator='tpu', precision=64)` is not implemented." + "`Trainer(accelerator='tpu', precision='64-true')` is not implemented." " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" " requesting this feature." ) @@ -578,7 +587,7 @@ def _validate_precision_choice(self) -> None: f" found: {self._precision_plugin_flag}." ) if isinstance(self.accelerator, HPUAccelerator): - if self._precision_flag not in ("16", "bf16", "32"): + if self._precision_flag not in ("16-mixed", "bf16-mixed", "32-true"): raise MisconfigurationException( f"`Trainer(accelerator='hpu', precision={self._precision_flag!r})` is not supported." 
                )
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 2926982ecc94c..cf5f33c0d89d3 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -122,7 +122,7 @@ def __init__(
         accelerator: Optional[Union[str, Accelerator]] = None,
         strategy: Optional[Union[str, Strategy]] = None,
         sync_batchnorm: bool = False,
-        precision: _PRECISION_INPUT = 32,
+        precision: _PRECISION_INPUT = "32-true",
         enable_model_summary: bool = True,
         num_sanity_val_steps: int = 2,
         profiler: Optional[Union[Profiler, str]] = None,
@@ -226,9 +226,10 @@ def __init__(
             plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
                 Default: ``None``.
 
-            precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
+            precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
+                16-bit mixed precision (16, '16' or '16-mixed') or bfloat16 mixed precision ('bf16' or 'bf16-mixed').
                 Can be used on CPU, GPU, TPUs, HPUs or IPUs.
-                Default: ``32``.
+                Default: ``'32-true'``.
 
             max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                 If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
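
Note (illustration only, not part of the patch): a minimal sketch of how the unified precision flags introduced above would be passed to the Trainer, assuming the `lightning.pytorch` package built from this branch. The `RandomDataset` and `BoringModel` classes below are toy definitions invented for the example.

# example_new_precision_flags.py -- illustrative sketch, not part of this patch.
# Assumes `lightning.pytorch` as built from this branch; the dataset and module
# below are toy definitions invented for the example.
import torch
from torch.utils.data import DataLoader, Dataset

import lightning.pytorch as pl


class RandomDataset(Dataset):
    """Tiny random-tensor dataset so the example is self-contained."""

    def __init__(self, size: int = 32, length: int = 64) -> None:
        self.data = torch.randn(length, size)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.data[index]


class BoringModel(pl.LightningModule):
    """Minimal LightningModule with a single linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        # Under "16-mixed"/"bf16-mixed", this forward runs inside torch.autocast.
        return self.layer(batch).sum()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # New-style flags follow the "<value>-<kind>" pattern: "32-true", "64-true",
    # "16-mixed", "bf16-mixed". Legacy values such as 16 or "bf16" are still
    # accepted but converted internally, with a warning for the mixed variants.
    trainer = pl.Trainer(accelerator="cpu", precision="bf16-mixed", max_steps=2, logger=False)
    trainer.fit(BoringModel(), DataLoader(RandomDataset(), batch_size=8))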