Introduce new precision layout in PL #16783

Merged (29 commits, Feb 17, 2023)

Changes from 1 commit

Commits (29):
f93e79e  adapt precision in fabric (justusschock, Feb 15, 2023)
682131c  adapt fabric tests to new precision (justusschock, Feb 15, 2023)
ff870f2  update docs (justusschock, Feb 16, 2023)
02591b0  fix cli (justusschock, Feb 16, 2023)
ad697a1  Merge branch 'master' into 2.0/precision (justusschock, Feb 16, 2023)
a5e4848  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 16, 2023)
457b6d1  changelog (justusschock, Feb 16, 2023)
0299a29  cli fixes (justusschock, Feb 16, 2023)
2e66197  fix tests and warnings (justusschock, Feb 16, 2023)
af09b84  update PL docs (justusschock, Feb 16, 2023)
747a08c  update examples (justusschock, Feb 16, 2023)
de2281d  update src code (justusschock, Feb 16, 2023)
b4e2e01  update tests (justusschock, Feb 16, 2023)
b553824  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 16, 2023)
8321309  Merge branch 'master' into 2.0/precision_PL (justusschock, Feb 17, 2023)
6976db7  update with latest fabric changes (justusschock, Feb 17, 2023)
a730bcc  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 17, 2023)
e4eb724  fix tests (justusschock, Feb 17, 2023)
1ea5221  Merge branch 'master' into 2.0/precision_PL (justusschock, Feb 17, 2023)
84e24aa  update test to see results (justusschock, Feb 17, 2023)
3f8c319  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 17, 2023)
cbb85ae  update tests (justusschock, Feb 17, 2023)
03482ce  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 17, 2023)
4d19de4  update utils (justusschock, Feb 17, 2023)
14b06ad  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 17, 2023)
01ef70e  update (justusschock, Feb 17, 2023)
69a46e5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 17, 2023)
4181d60  update (justusschock, Feb 17, 2023)
9cfbc24  revert pre-commit yet again (justusschock, Feb 17, 2023)
Viewing changes from commit f93e79ee9f3221aac9eb4d862d00a8ae02f11434 ("adapt precision in fabric", committed by justusschock on Feb 15, 2023):
8 changes: 4 additions & 4 deletions src/lightning/fabric/cli.py
@@ -20,6 +20,7 @@
from lightning_utilities.core.imports import RequirementCache

from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.device_parser import _parse_gpu_ids

@@ -28,7 +29,6 @@
_CLICK_AVAILABLE = RequirementCache("click")

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
_SUPPORTED_PRECISION = ("64", "32", "16", "bf16")


def _get_supported_strategies() -> List[str]:
@@ -106,11 +106,11 @@ def _get_supported_strategies() -> List[str]:
)
@click.option(
"--precision",
type=click.Choice(_SUPPORTED_PRECISION),
type=click.Choice(_PRECISION_INPUT_STR),
default="32",
help=(
"Double precision (``64``), full precision (``32``), half precision (``16``) or bfloat16 precision"
" (``'bf16'``)"
"Double precision (``64-true``), full precision (``32-true``), half precision (``16-mixed``) or "
"bfloat16 precision (``'bf16-mixed'``)"
),
)
@click.argument("script_args", nargs=-1, type=click.UNPROCESSED)
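For context, a minimal standalone sketch of the pattern this hunk moves to: deriving the CLI's `--precision` choices from the shared `_PRECISION_INPUT_STR` Literal alias. This is not the PR's exact code; `typing.get_args` is used here to unpack the Literal into strings for `click.Choice`, and the default is shown with the new `"32-true"` spelling (the hunk above still leaves the default at `"32"`; later commits in this PR revisit the CLI).

```python
# Sketch only, not the PR's exact code: expose a typing.Literal's values as CLI choices.
from typing import Literal, get_args

import click

_PRECISION_INPUT_STR = Literal["16-mixed", "bf16-mixed", "32-true", "64-true"]


@click.command()
@click.option(
    "--precision",
    type=click.Choice(get_args(_PRECISION_INPUT_STR)),  # ("16-mixed", "bf16-mixed", "32-true", "64-true")
    default="32-true",
    help="Precision setting forwarded to Fabric.",
)
def run(precision: str) -> None:
    click.echo(f"precision={precision}")


if __name__ == "__main__":
    run()
```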
56 changes: 37 additions & 19 deletions src/lightning/fabric/connector.py
@@ -42,7 +42,13 @@
)
from lightning.fabric.plugins.precision.double import DoublePrecision
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning.fabric.plugins.precision.precision import (
_PRECISION_INPUT,
_PRECISION_INPUT_INT,
_PRECISION_INPUT_STR,
_PRECISION_INPUT_STR_LEGACY,
_PRECISION_INPUT_STR_LEGACY_CONVERSION,
)
from lightning.fabric.strategies import (
DeepSpeedStrategy,
ParallelStrategy,
@@ -107,7 +113,7 @@ def __init__(
strategy = self._argument_from_env("strategy", strategy, default=None)
devices = self._argument_from_env("devices", devices, default=None)
num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
precision = self._argument_from_env("precision", precision, default=32)
precision = self._argument_from_env("precision", precision, default="32-true")

# 1. Parsing flags
# Get registered strategies, built-in accelerators and precision plugins
@@ -119,7 +125,7 @@ def __init__(
# For devices: Assign gpus, etc. to the accelerator flag and devices flag
self._strategy_flag: Optional[Union[Strategy, str]] = None
self._accelerator_flag: Optional[Union[Accelerator, str]] = None
self._precision_input: _PRECISION_INPUT_STR = "32"
self._precision_input: _PRECISION_INPUT_STR = "32-true"
self._precision_instance: Optional[Precision] = None
self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -220,10 +226,22 @@ def _check_config_and_set_final_flags(

self._accelerator_flag = accelerator

supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
supported_precision = (
get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_LEGACY)
)
if precision not in supported_precision:
raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}")
self._precision_input = cast(_PRECISION_INPUT_STR, str(precision))

precision = str(precision) # convert int flags to str here to enable the legacy-conversion below

if precision in get_args(_PRECISION_INPUT_STR_LEGACY):
rank_zero_warn(
f"{precision} is supported for historical reasons but its usage is discouraged. "
"Please set your precision to {_PRECISION_INPUT_STR_LEGACY_CONVERSION[precision]} instead!"
)
precision = _PRECISION_INPUT_STR_LEGACY_CONVERSION[precision]

self._precision_input = cast(_PRECISION_INPUT_STR, precision)

if plugins:
plugins_flags_types: Dict[str, int] = Counter()
@@ -453,34 +471,34 @@ def _check_and_init_precision(self) -> Precision:
return self._precision_instance

if isinstance(self.accelerator, TPUAccelerator):
if self._precision_input == "32":
if self._precision_input == "32-true":
return TPUPrecision()
elif self._precision_input in ("16", "bf16"):
if self._precision_input == "16":
elif self._precision_input in ("16-mixed", "bf16-mixed"):
if self._precision_input == "16-mixed":
rank_zero_warn(
"You passed `Fabric(accelerator='tpu', precision=16)` but AMP"
" is not supported with TPUs. Using `precision='bf16'` instead."
"You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
" is not supported with TPUs. Using `precision='bf16-mixed'` instead."
)
return TPUBf16Precision()
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore

if self._precision_input == "32":
if self._precision_input == "32-true":
return Precision()
if self._precision_input == "64":
if self._precision_input == "64-true":
return DoublePrecision()

if self._precision_input == "16" and self._accelerator_flag == "cpu":
if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu":
rank_zero_warn(
"You passed `Fabric(accelerator='cpu', precision=16)` but AMP is not supported on CPU."
" Using `precision='bf16'` instead."
"You passed `Fabric(accelerator='cpu', precision='16-mixed')` but AMP is not supported on CPU with "
"fp16. Using `precision='bf16-mixed'` instead."
)
self._precision_input = "bf16"
self._precision_input = "bf16-mixed"

if self._precision_input in ("16", "bf16"):
if self._precision_input in ("16-mixed", "bf16-mixed"):
rank_zero_info(
"Using 16-bit Automatic Mixed Precision (AMP)"
if self._precision_input == "16"
if self._precision_input == "16-mixed"
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
@@ -494,7 +512,7 @@ def _check_and_init_precision(self) -> Precision:
def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, and accelerator."""
if isinstance(self.accelerator, TPUAccelerator):
if self._precision_input == "64":
if self._precision_input == "64-true":
raise NotImplementedError(
"`Fabric(accelerator='tpu', precision=64)` is not implemented."
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
6 changes: 3 additions & 3 deletions src/lightning/fabric/fabric.py
@@ -67,8 +67,8 @@ class Fabric:
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
precision: Double precision (``64``), full precision (``32``), half precision (``16``),
or bfloat16 precision (``"bf16"``).
precision: Double precision (``"64-true"``), full precision (``"32-true"``), half precision AMP (``"16-mixed"``),
or bfloat16 precision AMP (``"bf16-mixed"``).
plugins: One or several custom plugins
callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.
@@ -82,7 +82,7 @@ def __init__(
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: _PRECISION_INPUT = 32,
precision: _PRECISION_INPUT = "32-true",
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
callbacks: Optional[Union[List[Any], Any]] = None,
loggers: Optional[Union[Logger, List[Logger]]] = None,
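A short usage sketch of the updated `Fabric` signature, assuming a lightning 2.0-era install and a CUDA device with bfloat16 support; legacy values such as `precision=16` still work but are converted (with a warning) as shown above.

```python
from lightning.fabric import Fabric

# New-style precision strings; "bf16-mixed" enables bfloat16 autocast on supported hardware.
fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")
fabric.launch()
```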
17 changes: 10 additions & 7 deletions src/lightning/fabric/plugins/precision/amp.py
@@ -29,21 +29,24 @@ class MixedPrecision(Precision):
"""Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.

Args:
precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
precision: Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``).
device: The device for ``torch.autocast``.
scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
"""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> None:
self.precision = cast(Literal["16", "bf16"], str(precision))
if scaler is None and self.precision == "16":
self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
if scaler is None and self.precision == "16-mixed":
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16":
raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
self.scaler = scaler

@@ -53,7 +56,7 @@ def forward_context(self) -> Generator[None, None, None]:
yield

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16}
precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)
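If the plugin is constructed by hand rather than through the `precision=` flag, the renamed arguments look like the sketch below (an assumption-laden example: it presumes a CUDA device, and relies on the behavior from this hunk that a `GradScaler` is created automatically for `"16-mixed"` when none is passed).

```python
from lightning.fabric import Fabric
from lightning.fabric.plugins.precision.amp import MixedPrecision

# Build the AMP plugin explicitly with the new precision spelling and hand it to Fabric.
amp = MixedPrecision(precision="16-mixed", device="cuda")
fabric = Fabric(accelerator="cuda", devices=1, plugins=amp)
```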

14 changes: 6 additions & 8 deletions src/lightning/fabric/plugins/precision/deepspeed.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, cast, Literal, TYPE_CHECKING, Union
from typing import Any, cast, Literal, TYPE_CHECKING

import torch
from torch import Tensor
@@ -27,33 +27,31 @@
if _DEEPSPEED_AVAILABLE: # type: ignore[has-type]
import deepspeed

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]


class DeepSpeedPrecision(Precision):
"""Precision plugin for DeepSpeed integration.

Args:
precision: Full precision (32), half precision (16) or bfloat16 precision (bf16).
precision: Full precision (32-true), half precision (16-mixed) or bfloat16 precision (bf16-mixed).

Raises:
ValueError:
If unsupported ``precision`` is provided.
"""

def __init__(self, precision: _PRECISION_INPUT) -> None:
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in DeepSpeed."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision))
self.precision = cast(_PRECISION_INPUT, precision)

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16, "32": torch.float32}
precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16, "32-true": torch.float32}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)
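The renamed keys of the DeepSpeed plugin's dtype table can be exercised without DeepSpeed installed; a small sketch:

```python
import torch

# Mapping used by DeepSpeedPrecision.convert_input after this change.
precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16, "32-true": torch.float32}

x = torch.randn(4, 4)
print(x.to(precision_to_type["16-mixed"]).dtype)    # torch.float16
print(x.to(precision_to_type["bf16-mixed"]).dtype)  # torch.bfloat16
```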

2 changes: 1 addition & 1 deletion src/lightning/fabric/plugins/precision/double.py
@@ -25,7 +25,7 @@
class DoublePrecision(Precision):
"""Plugin for training with double (``torch.float64``) precision."""

precision: Literal["64"] = "64"
precision: Literal["64-true"] = "64-true"

def convert_module(self, module: Module) -> Module:
return module.double()
8 changes: 4 additions & 4 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -27,7 +27,7 @@ class FSDPPrecision(MixedPrecision):
"""AMP for Fully Sharded Data Parallel training."""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")
@@ -37,16 +37,16 @@ def __init__(
super().__init__(
precision=precision,
device=device,
scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None),
scaler=(ShardedGradScaler() if scaler is None and precision == "16-mixed" else None),
)

@property
def mixed_precision_config(self) -> "TorchMixedPrecision":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

if self.precision == "16":
if self.precision == "16-mixed":
dtype = torch.float16
elif self.precision == "bf16":
elif self.precision == "bf16-mixed":
dtype = torch.bfloat16
else:
raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
8 changes: 5 additions & 3 deletions src/lightning/fabric/plugins/precision/precision.py
@@ -23,8 +23,10 @@
from lightning.fabric.utilities.types import _PARAMETERS, Optimizable

_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT_STR_LEGACY_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}
_PRECISION_INPUT_STR_LEGACY = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT_STR = Literal["16-mixed", "bf16-mixed", "32-true", "64-true"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_LEGACY]


class Precision:
@@ -33,7 +35,7 @@ class Precision:
The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
"""

precision: _PRECISION_INPUT_STR = "32"
precision: _PRECISION_INPUT_STR = "32-true"

def convert_module(self, module: Module) -> Module:
"""Convert the module parameters to the precision type this plugin handles.
2 changes: 1 addition & 1 deletion src/lightning/fabric/plugins/precision/tpu_bf16.py
@@ -24,7 +24,7 @@
class TPUBf16Precision(TPUPrecision):
"""Plugin that enables bfloats on TPUs."""

precision: Literal["bf16"] = "bf16"
precision: Literal["bf16-mixed"] = "bf16-mixed"

def __init__(self) -> None:
super().__init__()
12 changes: 6 additions & 6 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -105,8 +105,8 @@ def __init__(

Arguments:

zero_optimization: Enable ZeRO optimization. This is compatible with either ``precision=16`` or
``precision="bf16"``.
zero_optimization: Enable ZeRO optimization. This is compatible with either ``precision="16-mixed"`` or
``precision="bf16-mixed"``.

stage: Different stages of the ZeRO Optimizer. 0 is disabled,
1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning,
@@ -350,9 +350,9 @@ def module_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
assert self._config_initialized

if self.precision.precision == "16":
if self.precision.precision == "16-mixed":
dtype = torch.float16
elif self.precision.precision == "bf16":
elif self.precision.precision == "bf16-mixed":
dtype = torch.bfloat16
else:
dtype = torch.float32
@@ -604,7 +604,7 @@ def _format_config(self) -> None:

def _format_precision_config(self) -> None:
assert isinstance(self.config, dict)
if self.precision.precision == "16":
if self.precision.precision == "16-mixed":
if "fp16" not in self.config:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")
@@ -616,7 +616,7 @@ def _format_precision_config(self) -> None:
"hysteresis": self.hysteresis,
"min_loss_scale": self.min_loss_scale,
}
elif "bf16" not in self.config and self.precision.precision == "bf16":
elif "bf16" not in self.config and self.precision.precision == "bf16-mixed":
rank_zero_info("Enabling DeepSpeed BF16.")
self.config["bf16"] = {"enabled": True}

5 changes: 3 additions & 2 deletions src/lightning/fabric/strategies/fsdp.py
@@ -76,8 +76,9 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows
users to enable two different backward prefetching algorithms to help backward communication and
computation overlapping. The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision="16-mixed"`` or
BF16 if ``precision="bf16-mixed"`` unless a config is passed in.
This is only available in PyTorch 1.12 and later.
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention + feed-forward).
Enabling this can free up a significant amount of memory at the cost of speed since activations in
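Finally, a usage sketch combining FSDP with the new spellings, under the assumptions that at least two CUDA GPUs are available, PyTorch is >= 1.12, and `"fsdp"` is the registered strategy name:

```python
from lightning.fabric import Fabric

# FSDP derives its MixedPrecision dtype from the precision string, as shown above.
fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp", precision="16-mixed")
fabric.launch()
```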