
Introduce new precision layout in fabric (#16767)
justusschock authored Feb 17, 2023
1 parent 3a354ac commit ac5fa03
Showing 33 changed files with 214 additions and 135 deletions.
14 changes: 9 additions & 5 deletions docs/source-pytorch/fabric/api/fabric_args.rst
@@ -112,23 +112,27 @@ Learn more about :ref:`distributed multi-node training on clusters <Fabric Clust
precision
=========

Fabric supports double precision (64), full precision (32), or half-precision (16) operation (including `bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_).
Fabric supports double precision (64 bit), full precision (32 bit), or half-precision (16 bit) floating point operation (including `bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_).
Half precision, or mixed precision, combines 32 and 16-bit floating points to reduce the memory footprint during model training.
Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while settings that only work in the specified precision have a ``"-true"`` suffix.
This can result in improved performance, achieving significant speedups on modern GPUs.

.. code-block:: python
# Default used by the Fabric
fabric = Fabric(precision=32, devices=1)
fabric = Fabric(precision="32-true", devices=1)
# the same as:
fabric = Fabric(precision="32", devices=1)
# 16-bit (mixed) precision
fabric = Fabric(precision=16, devices=1)
fabric = Fabric(precision="16-mixed", devices=1)
# 16-bit bfloat precision
fabric = Fabric(precision="bf16", devices=1)
fabric = Fabric(precision="bf16-mixed", devices=1)
# 64-bit (double) precision
fabric = Fabric(precision=64, devices=1)
fabric = Fabric(precision="64-true", devices=1)
See also: :doc:`../fundamentals/precision`
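Passing a value outside this set raises a ``ValueError`` that lists the allowed options. A minimal sketch (the invalid value below is purely illustrative):

.. code-block:: python

    from lightning.fabric import Fabric

    try:
        Fabric(precision="8-bit", devices=1)
    except ValueError as err:
        # e.g. "Precision '8-bit' is invalid. Allowed precision values: ..."
        print(err)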

10 changes: 7 additions & 3 deletions docs/source-pytorch/fabric/fundamentals/launch.rst
@@ -68,9 +68,13 @@ This is essentially the same as running ``python path/to/your/script.py``, but i
--main-port, --main_port INTEGER
The main port to connect to the main
machine.
--precision [64|32|16|bf16] Double precision (``64``), full precision
(``32``), half precision (``16``) or
bfloat16 precision (``'bf16'``)
--precision [16-mixed|bf16-mixed|32-true|64-true|64|32|16|bf16]
Double precision (``64-true`` or ``64``),
full precision (``32-true`` or ``32``), half
precision (``16-mixed`` or ``16``) or
bfloat16 precision (``bf16-mixed`` or
``bf16``)
--help Show this message and exit.
32 changes: 19 additions & 13 deletions docs/source-pytorch/fabric/fundamentals/precision.rst
@@ -24,26 +24,35 @@ This is how you select the precision in Fabric:
from lightning.fabric import Fabric
# This is the default
fabric = Fabric(precision="32-true")
# Also FP32
fabric = Fabric(precision=32)
# FP16 mixed precision
fabric = Fabric(precision=16)
# FP32 as well
fabric = Fabric(precision="32")
# Precision values can also be set as a string
fabric = Fabric(precision="16")
# FP16 mixed precision
fabric = Fabric(precision="16-mixed")
# BFloat16 precision (Volta GPUs and later)
fabric = Fabric(precision="bf16")
fabric = Fabric(precision="bf16-mixed")
# Double precision
fabric = Fabric(precision="64-true")
# Or
fabric = Fabric(precision="64")
# Or
fabric = Fabric(precision=64)
The same values can also be set through the :doc:`command line interface <launch>`:
.. code-block:: bash
lightning run model train.py --precision=bf16
lightning run model train.py --precision=bf16-mixed
.. note::
@@ -70,14 +79,11 @@ This is how you enable FP16 in Fabric:
.. code-block:: python
# Select FP16 mixed precision
fabric = Fabric(precision=16)
# Or as a string
fabric = Fabric(precision="16")
fabric = Fabric(precision="16-mixed")
.. note::
When using TPUs, setting ``precision=16`` will enable bfloat16, the only supported half-precision type on TPUs.
When using TPUs, setting ``precision="16-mixed"`` will enable bfloat16 based mixed precision, the only supported half-precision type on TPUs.
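For a concrete picture, here is a minimal end-to-end sketch of a training step under ``"16-mixed"`` (assuming a CUDA-capable GPU; the tiny model and random batch are illustrative placeholders):

.. code-block:: python

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cuda", devices=1, precision="16-mixed")
    fabric.launch()

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    batch = fabric.to_device(torch.randn(8, 32))
    output = model(batch)  # the forward pass runs under torch.autocast in float16
    loss = output.sum()
    fabric.backward(loss)  # loss scaling is handled internally, no GradScaler boilerplate
    optimizer.step()
    optimizer.zero_grad()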
----
@@ -94,7 +100,7 @@ For more information, see `this TPU performance blog post <https://cloud.google.
.. code-block:: python
# Select BF16 precision
fabric = Fabric(precision="bf16")
fabric = Fabric(precision="bf16-mixed")
Under the hood, we use `torch.autocast <https://pytorch.org/docs/stable/amp.html>`__ with the dtype set to ``bfloat16``, with no gradient scaling.
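In other words, roughly the following happens around your forward pass (a simplified, self-contained sketch; CPU autocast is used here purely for illustration):

.. code-block:: python

    import torch

    model = torch.nn.Linear(8, 2)
    batch = torch.randn(4, 8)

    # bfloat16 autocast around the forward pass; no GradScaler is involved,
    # unlike float16 mixed precision
    with torch.autocast("cpu", dtype=torch.bfloat16):
        output = model(batch)

    print(output.dtype)  # torch.bfloat16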
@@ -117,7 +123,7 @@ Fabric automatically casts the data type and operations in the ``forward`` of yo
.. code-block:: python
fabric = Fabric(precision="bf16")
fabric = Fabric(precision="bf16-mixed")
model = ...
optimizer = ...
2 changes: 1 addition & 1 deletion docs/source-pytorch/fabric/guide/multi_node/cloud.rst
@@ -50,7 +50,7 @@ Launch multi-node training in the cloud
def run(self):
# Set up Fabric
# The `devices` and `num_nodes` gets set by Lightning automatically
fabric = L.Fabric(strategy="ddp", precision=16)
fabric = L.Fabric(strategy="ddp", precision="16-mixed")
# Your training code
model = ...
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781))


- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16767](https://github.com/Lightning-AI/lightning/pull/16767))

### Deprecated

-
11 changes: 6 additions & 5 deletions src/lightning/fabric/cli.py
@@ -18,8 +18,10 @@
from typing import Any, List, Optional

from lightning_utilities.core.imports import RequirementCache
from typing_extensions import get_args

from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.device_parser import _parse_gpu_ids

@@ -28,7 +30,6 @@
_CLICK_AVAILABLE = RequirementCache("click")

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
_SUPPORTED_PRECISION = ("64", "32", "16", "bf16")


def _get_supported_strategies() -> List[str]:
@@ -106,11 +107,11 @@ def _get_supported_strategies() -> List[str]:
)
@click.option(
"--precision",
type=click.Choice(_SUPPORTED_PRECISION),
default="32",
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)),
default="32-true",
help=(
"Double precision (``64``), full precision (``32``), half precision (``16``) or bfloat16 precision"
" (``'bf16'``)"
"Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), "
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
),
)
@click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
68 changes: 45 additions & 23 deletions src/lightning/fabric/connector.py
@@ -42,7 +42,13 @@
)
from lightning.fabric.plugins.precision.double import DoublePrecision
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning.fabric.plugins.precision.precision import (
_PRECISION_INPUT,
_PRECISION_INPUT_INT,
_PRECISION_INPUT_STR,
_PRECISION_INPUT_STR_ALIAS,
_PRECISION_INPUT_STR_ALIAS_CONVERSION,
)
from lightning.fabric.strategies import (
DeepSpeedStrategy,
ParallelStrategy,
@@ -98,7 +104,7 @@ def __init__(
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: _PRECISION_INPUT = 32,
precision: _PRECISION_INPUT = "32-true",
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
) -> None:

@@ -107,7 +113,7 @@ def __init__(
strategy = self._argument_from_env("strategy", strategy, default=None)
devices = self._argument_from_env("devices", devices, default=None)
num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
precision = self._argument_from_env("precision", precision, default=32)
precision = self._argument_from_env("precision", precision, default="32-true")

# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
@@ -119,7 +125,7 @@ def __init__(
# For devices: Assign gpus, etc. to the accelerator flag and devices flag
self._strategy_flag: Optional[Union[Strategy, str]] = None
self._accelerator_flag: Optional[Union[Accelerator, str]] = None
self._precision_input: _PRECISION_INPUT_STR = "32"
self._precision_input: _PRECISION_INPUT_STR = "32-true"
self._precision_instance: Optional[Precision] = None
self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -220,10 +226,7 @@ def _check_config_and_set_final_flags(

self._accelerator_flag = accelerator

supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
if precision not in supported_precision:
raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
self._precision_input = cast(_PRECISION_INPUT_STR, str(precision))
self._precision_input = _convert_precision_to_unified_args(precision)

if plugins:
plugins_flags_types: Dict[str, int] = Counter()
@@ -453,34 +456,34 @@ def _check_and_init_precision(self) -> Precision:
return self._precision_instance

if isinstance(self.accelerator, TPUAccelerator):
if self._precision_input == "32":
if self._precision_input == "32-true":
return TPUPrecision()
elif self._precision_input in ("16", "bf16"):
if self._precision_input == "16":
elif self._precision_input in ("16-mixed", "bf16-mixed"):
if self._precision_input == "16-mixed":
rank_zero_warn(
"You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
" is not supported with TPUs. Using `precision='bf16'` instead."
"You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
" is not supported with TPUs. Using `precision='bf16-mixed'` instead."
)
return TPUBf16Precision()
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore

if self._precision_input == "32":
if self._precision_input == "32-true":
return Precision()
if self._precision_input == "64":
if self._precision_input == "64-true":
return DoublePrecision()

if self._precision_input == "16" and self._accelerator_flag == "cpu":
if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
rank_zero_warn(
"You passed `Fabric(accelerator='cpu', precision=16)` but AMP is not supported on CPU."
" Using `precision='bf16'` instead."
"You passed `Fabric(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on "
"CPU. Using `precision='bf16-mixed'` instead."
)
self._precision_input = "bf16"
self._precision_input = "bf16-mixed"

if self._precision_input in ("16", "bf16"):
if self._precision_input in ("16-mixed", "bf16-mixed"):
rank_zero_info(
"Using 16-bit Automatic Mixed Precision (AMP)"
if self._precision_input == "16"
if self._precision_input == "16-mixed"
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
@@ -494,9 +497,9 @@ def _check_and_init_precision(self) -> Precision:
def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, and accelerator."""
if isinstance(self.accelerator, TPUAccelerator):
if self._precision_input == "64":
if self._precision_input == "64-true":
raise NotImplementedError(
"`Fabric(accelerator='tpu', precision=64)` is not implemented."
"`Fabric(accelerator='tpu', precision='64-true')` is not implemented."
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
" requesting this feature."
)
@@ -561,3 +564,22 @@ def _argument_from_env(name: str, current: Any, default: Any) -> Any:
if env_value is None:
return current
return env_value


def _convert_precision_to_unified_args(precision: _PRECISION_INPUT) -> _PRECISION_INPUT_STR:
supported_precision = (
get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_ALIAS)
)
if precision not in supported_precision:
raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")

precision = str(precision) # convert int flags to str here to enable the legacy-conversion below

if precision in get_args(_PRECISION_INPUT_STR_ALIAS):
if str(precision)[:2] not in ("32", "64"):
rank_zero_warn(
f"{precision} is supported for historical reasons but its usage is discouraged. "
f"Please set your precision to {_PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]} instead!"
)
precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]
return cast(_PRECISION_INPUT_STR, precision)
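# Illustrative usage sketch of the conversion helper above (e.g. imported as
# `from lightning.fabric.connector import _convert_precision_to_unified_args`); the alias
# mapping is assumed to follow the CHANGELOG entry in this PR, i.e. "64"/"32"/"16"/"bf16"
# map to "64-true"/"32-true"/"16-mixed"/"bf16-mixed".
assert _convert_precision_to_unified_args("32-true") == "32-true"  # already unified, passed through
assert _convert_precision_to_unified_args(64) == "64-true"  # legacy int alias, converted without a warning
assert _convert_precision_to_unified_args("16") == "16-mixed"  # warns that the legacy alias is discouraged
assert _convert_precision_to_unified_args("bf16") == "bf16-mixed"  # warns as well
# Any other value, e.g. _convert_precision_to_unified_args("8"), raises a ValueError listing the allowed options.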
6 changes: 3 additions & 3 deletions src/lightning/fabric/fabric.py
@@ -67,8 +67,8 @@ class Fabric:
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
precision: Double precision (``64``), full precision (``32``), half precision (``16``),
or bfloat16 precision (``"bf16"``).
precision: Double precision (``"64-true"``), full precision (``"32-true"``), half precision AMP (``"16-mixed"``),
or bfloat16 precision AMP (``"bf16-mixed"``).
plugins: One or several custom plugins
callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
@@ -82,7 +82,7 @@ def __init__(
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: _PRECISION_INPUT = 32,
precision: _PRECISION_INPUT = "32-true",
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
callbacks: Optional[Union[List[Any], Any]] = None,
loggers: Optional[Union[Logger, List[Logger]]] = None,
19 changes: 11 additions & 8 deletions src/lightning/fabric/plugins/precision/amp.py
@@ -29,21 +29,24 @@ class MixedPrecision(Precision):
"""Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.
Args:
precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
precision: Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``).
device: The device for ``torch.autocast``.
scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
"""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> None:
self.precision = cast(Literal["16", "bf16"], str(precision))
if scaler is None and self.precision == "16":
self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
if scaler is None and self.precision == "16-mixed":
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16":
raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
self.scaler = scaler

@@ -53,7 +56,7 @@ def forward_context(self) -> Generator[None, None, None]:
yield

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16}
precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)

@@ -89,4 +92,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def _autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)
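# Illustrative usage sketch of the updated plugin; Fabric normally constructs it from the
# `precision` flag, so building it directly like this is only for demonstration.
import torch

from lightning.fabric.plugins.precision.amp import MixedPrecision

fp16 = MixedPrecision(precision="16-mixed", device="cuda")  # a GradScaler is created automatically
assert isinstance(fp16.scaler, torch.cuda.amp.GradScaler)

bf16 = MixedPrecision(precision="bf16-mixed", device="cuda")  # bfloat16 never uses a scaler
assert bf16.scaler is None

try:
    MixedPrecision(precision="bf16-mixed", device="cuda", scaler=torch.cuda.amp.GradScaler())
except ValueError as err:
    print(err)  # `precision='bf16-mixed'` does not use a scaler, found ...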
