Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce new precision layout in fabric #16767

Merged
merged 28 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f93e79e
adapt precision in fabric
justusschock Feb 15, 2023
682131c
adapt fabric tests to new precision
justusschock Feb 15, 2023
ff870f2
update docs
justusschock Feb 16, 2023
02591b0
fix cli
justusschock Feb 16, 2023
ad697a1
Merge branch 'master' into 2.0/precision
justusschock Feb 16, 2023
a5e4848
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2023
457b6d1
changelog
justusschock Feb 16, 2023
0299a29
cli fixes
justusschock Feb 16, 2023
2e66197
fix tests and warnings
justusschock Feb 16, 2023
9a9faf4
update with discussion result
justusschock Feb 16, 2023
4d6c263
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2023
d4a212a
ignore typing for now
justusschock Feb 16, 2023
fe5b84d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2023
5ff11f1
update typing
justusschock Feb 16, 2023
7943fea
Merge branch 'master' into 2.0/precision
justusschock Feb 16, 2023
54e1c4b
Update docs/source-pytorch/fabric/api/fabric_args.rst
justusschock Feb 16, 2023
20444b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2023
6e7d6e8
Update tests/tests_fabric/test_connector.py
justusschock Feb 16, 2023
b348c55
Update src/lightning/fabric/connector.py
justusschock Feb 16, 2023
f60a370
reviewer comments
justusschock Feb 17, 2023
56c55bb
Merge branch 'master' into 2.0/precision
justusschock Feb 17, 2023
47eb750
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
0db2aa9
fix merge error
justusschock Feb 17, 2023
b6b7f58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
0470bee
fu precommit
justusschock Feb 17, 2023
9677f2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
4f0bf8f
mypy
justusschock Feb 17, 2023
3bdd879
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions docs/source-pytorch/fabric/api/fabric_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,24 @@ Learn more about :ref:`distributed multi-node training on clusters <Fabric Clust
precision
=========

Fabric supports double precision (64), full precision (32), or half-precision (16) operation (including `bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_).
Fabric supports double precision (64 bit), full precision (32 bit), or half-precision (16 bit) floating point operation (including `bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_).
Half precision, or mixed precision, combines 32 and 16-bit floating points to reduce the memory footprint during model training.
Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while settings that only work in the specified precision have a ``"-true"`` suffix.
This can result in improved performance, achieving significant speedups on modern GPUs.

.. code-block:: python

# Default used by the Fabric
fabric = Fabric(precision=32, devices=1)
fabric = Fabric(precision="32-true", devices=1)
justusschock marked this conversation as resolved.
Show resolved Hide resolved

# 16-bit (mixed) precision
fabric = Fabric(precision=16, devices=1)
fabric = Fabric(precision="16-mixed", devices=1)

# 16-bit bfloat precision
fabric = Fabric(precision="bf16", devices=1)
fabric = Fabric(precision="bf16-mixed", devices=1)

# 64-bit (double) precision
fabric = Fabric(precision=64, devices=1)
fabric = Fabric(precision="64-true", devices=1)

See also: :doc:`../fundamentals/precision`

Expand Down
10 changes: 7 additions & 3 deletions docs/source-pytorch/fabric/fundamentals/launch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,13 @@ This is essentially the same as running ``python path/to/your/script.py``, but i
--main-port, --main_port INTEGER
The main port to connect to the main
machine.
--precision [64|32|16|bf16] Double precision (``64``), full precision
(``32``), half precision (``16``) or
bfloat16 precision (``'bf16'``)
--precision [16-mixed|bf16-mixed|32-true|64-true|64|32|16|bf16]
Double precision (``64-true`` or ``64``),
full precision (``32-true`` or ``64``), half
precision (``16-mixed`` or ``16``) or
bfloat16 precision (``bf16-mixed`` or
``bf16``)

--help Show this message and exit.


Expand Down
32 changes: 19 additions & 13 deletions docs/source-pytorch/fabric/fundamentals/precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,35 @@ This is how you select the precision in Fabric:
from lightning.fabric import Fabric

# This is the default
fabric = Fabric(precision="32-true")

# Also FP32
fabric = Fabric(precision=32)

# FP16 mixed precision
fabric = Fabric(precision=16)
# FP32 as well
fabric = Fabric(precision="32")

# Precision values can also be set as a string
fabric = Fabric(precision="16")
# FP16 mixed precision
fabric = Fabric(precision="16-mixed)

# BFloat16 precision (Volta GPUs and later)
fabric = Fabric(precision="bf16")
fabric = Fabric(precision="bf16-mixed")

# Double precision
fabric = Fabric(precision="64-true")

# Or
fabric = Fabric(precision="64")

# Or
fabric = Fabric(precision=64)


The same values can also be set through the :doc:`command line interface <launch>`:

.. code-block:: bash

lightning run model train.py --precision=bf16
lightning run model train.py --precision=bf16-mixed


.. note::
Expand All @@ -70,14 +79,11 @@ This is how you enable FP16 in Fabric:
.. code-block:: python

# Select FP16 mixed precision
fabric = Fabric(precision=16)

# Or as a string
fabric = Fabric(precision="16")
fabric = Fabric(precision="16-mixed")

.. note::

When using TPUs, setting ``precision=16`` will enable bfloat16, the only supported half-precision type on TPUs.
When using TPUs, setting ``precision="16-mixed"`` will enable bfloat16 based mixed precision, the only supported half-precision type on TPUs.


----
Expand All @@ -94,7 +100,7 @@ For more information, see `this TPU performance blog post <https://cloud.google.
.. code-block:: python

# Select BF16 precision
fabric = Fabric(precision="bf16")
fabric = Fabric(precision="bf16-mixed")


Under the hood, we use `torch.autocast <https://pytorch.org/docs/stable/amp.html>`__ with the dtype set to ``bfloat16``, with no gradient scaling.
Expand All @@ -117,7 +123,7 @@ Fabric automatically casts the data type and operations in the ``forward`` of yo

.. code-block:: python

fabric = Fabric(precision="bf16")
fabric = Fabric(precision="bf16-mixed")

model = ...
optimizer = ...
Expand Down
2 changes: 1 addition & 1 deletion docs/source-pytorch/fabric/guide/multi_node/cloud.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Launch multi-node training in the cloud
def run(self):
# Set up Fabric
# The `devices` and `num_nodes` gets set by Lightning automatically
fabric = L.Fabric(strategy="ddp", precision=16)
fabric = L.Fabric(strategy="ddp", precision="16-mixed")

# Your training code
model = ...
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled all shorthand strategy names that can be supported in the CLI ([#16485](https://github.com/Lightning-AI/lightning/pull/16485))


- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16767](https://github.com/Lightning-AI/lightning/pull/16767))

### Deprecated

-
Expand Down
14 changes: 9 additions & 5 deletions src/lightning/fabric/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
from typing import Any, List, Optional

from lightning_utilities.core.imports import RequirementCache
from typing_extensions import get_args

from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning.fabric.plugins.precision.precision import (
_PRECISION_INPUT_STR,
_PRECISION_INPUT_STR_LEGACY,
)
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.device_parser import _parse_gpu_ids

Expand All @@ -28,7 +33,6 @@
_CLICK_AVAILABLE = RequirementCache("click")

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
_SUPPORTED_PRECISION = ("64", "32", "16", "bf16")


def _get_supported_strategies() -> List[str]:
Expand Down Expand Up @@ -106,11 +110,11 @@ def _get_supported_strategies() -> List[str]:
)
@click.option(
"--precision",
type=click.Choice(_SUPPORTED_PRECISION),
default="32",
type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_LEGACY)),
default="32-true",
help=(
"Double precision (``64``), full precision (``32``), half precision (``16``) or bfloat16 precision"
" (``'bf16'``)"
"Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), "
"half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)"
),
)
@click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
Expand Down
61 changes: 40 additions & 21 deletions src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@
)
from lightning.fabric.plugins.precision.double import DoublePrecision
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning.fabric.plugins.precision.precision import (
_PRECISION_INPUT,
_PRECISION_INPUT_INT,
_PRECISION_INPUT_STR,
_PRECISION_INPUT_STR_LEGACY,
_PRECISION_INPUT_STR_LEGACY_CONVERSION,
)
from lightning.fabric.strategies import (
DeepSpeedStrategy,
ParallelStrategy,
Expand Down Expand Up @@ -98,7 +104,7 @@ def __init__(
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: _PRECISION_INPUT = 32,
precision: _PRECISION_INPUT = "32-true",
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
) -> None:

Expand All @@ -107,7 +113,7 @@ def __init__(
strategy = self._argument_from_env("strategy", strategy, default=None)
devices = self._argument_from_env("devices", devices, default=None)
num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
precision = self._argument_from_env("precision", precision, default=32)
precision = self._argument_from_env("precision", precision, default="32-true")

# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
Expand All @@ -119,7 +125,7 @@ def __init__(
# For devices: Assign gpus, etc. to the accelerator flag and devices flag
self._strategy_flag: Optional[Union[Strategy, str]] = None
self._accelerator_flag: Optional[Union[Accelerator, str]] = None
self._precision_input: _PRECISION_INPUT_STR = "32"
self._precision_input: _PRECISION_INPUT_STR = "32-true"
self._precision_instance: Optional[Precision] = None
self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
self._parallel_devices: List[Union[int, torch.device, str]] = []
Expand Down Expand Up @@ -220,10 +226,23 @@ def _check_config_and_set_final_flags(

self._accelerator_flag = accelerator

supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
supported_precision = (
get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_LEGACY)
)
if precision not in supported_precision:
raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
self._precision_input = cast(_PRECISION_INPUT_STR, str(precision))

precision = str(precision) # convert int flags to str here to enable the legacy-conversion below

if precision in get_args(_PRECISION_INPUT_STR_LEGACY):
if str(precision)[:2] not in ("32", "64"):
justusschock marked this conversation as resolved.
Show resolved Hide resolved
carmocca marked this conversation as resolved.
Show resolved Hide resolved
rank_zero_warn(
f"{precision} is supported for historical reasons but its usage is discouraged. "
f"Please set your precision to {_PRECISION_INPUT_STR_LEGACY_CONVERSION[precision]} instead!"
)
precision = _PRECISION_INPUT_STR_LEGACY_CONVERSION[precision]

self._precision_input = cast(_PRECISION_INPUT_STR, precision)
justusschock marked this conversation as resolved.
Show resolved Hide resolved

if plugins:
plugins_flags_types: Dict[str, int] = Counter()
Expand Down Expand Up @@ -453,34 +472,34 @@ def _check_and_init_precision(self) -> Precision:
return self._precision_instance

if isinstance(self.accelerator, TPUAccelerator):
if self._precision_input == "32":
if self._precision_input == "32-true":
return TPUPrecision()
elif self._precision_input in ("16", "bf16"):
if self._precision_input == "16":
elif self._precision_input in ("16-mixed", "bf16-mixed"):
if self._precision_input == "16-mixed":
rank_zero_warn(
"You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
" is not supported with TPUs. Using `precision='bf16'` instead."
"You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
" is not supported with TPUs. Using `precision='bf16-mixed'` instead."
)
return TPUBf16Precision()
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore

if self._precision_input == "32":
if self._precision_input == "32-true":
return Precision()
if self._precision_input == "64":
if self._precision_input == "64-true":
return DoublePrecision()

if self._precision_input == "16" and self._accelerator_flag == "cpu":
if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
rank_zero_warn(
"You passed `Fabric(accelerator='cpu', precision=16)` but AMP is not supported on CPU."
" Using `precision='bf16'` instead."
"You passed `Fabric(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on "
"CPU. Using `precision='bf16-mixed'` instead."
)
self._precision_input = "bf16"
self._precision_input = "bf16-mixed"

if self._precision_input in ("16", "bf16"):
if self._precision_input in ("16-mixed", "bf16-mixed"):
rank_zero_info(
"Using 16-bit Automatic Mixed Precision (AMP)"
if self._precision_input == "16"
if self._precision_input == "16-mixed"
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
Expand All @@ -494,9 +513,9 @@ def _check_and_init_precision(self) -> Precision:
def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, and accelerator."""
if isinstance(self.accelerator, TPUAccelerator):
if self._precision_input == "64":
if self._precision_input == "64-true":
raise NotImplementedError(
"`Fabric(accelerator='tpu', precision=64)` is not implemented."
"`Fabric(accelerator='tpu', precision='64-true')` is not implemented."
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
" requesting this feature."
)
Expand Down
6 changes: 3 additions & 3 deletions src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class Fabric:
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
precision: Double precision (``64``), full precision (``32``), half precision (``16``),
or bfloat16 precision (``"bf16"``).
precision: Double precision (``"64-true"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
or bfloat16 precision AMP (``"bf16-mixed"``).
plugins: One or several custom plugins
callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
Expand All @@ -82,7 +82,7 @@ def __init__(
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: _PRECISION_INPUT = 32,
precision: _PRECISION_INPUT = "32-true",
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
callbacks: Optional[Union[List[Any], Any]] = None,
loggers: Optional[Union[Logger, List[Logger]]] = None,
Expand Down
19 changes: 11 additions & 8 deletions src/lightning/fabric/plugins/precision/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,24 @@ class MixedPrecision(Precision):
"""Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.

Args:
precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
precision: Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``).
device: The device for ``torch.autocast``.
scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
"""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self,
precision: Literal["16-mixed", "bf16-mixed"],
carmocca marked this conversation as resolved.
Show resolved Hide resolved
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> None:
self.precision = cast(Literal["16", "bf16"], str(precision))
if scaler is None and self.precision == "16":
self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
if scaler is None and self.precision == "16-mixed":
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16":
raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
self.scaler = scaler

Expand All @@ -53,7 +56,7 @@ def forward_context(self) -> Generator[None, None, None]:
yield

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16}
precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)

Expand Down Expand Up @@ -89,4 +92,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def _autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)
Loading