Introduce new precision layout in PL #16783

Merged · 29 commits · Feb 17, 2023

Changes from all commits (29 commits)
f93e79e
adapt precision in fabric
justusschock Feb 15, 2023
682131c
adapt fabric tests to new precision
justusschock Feb 15, 2023
ff870f2
update docs
justusschock Feb 16, 2023
02591b0
fix cli
justusschock Feb 16, 2023
ad697a1
Merge branch 'master' into 2.0/precision
justusschock Feb 16, 2023
a5e4848
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2023
457b6d1
changelog
justusschock Feb 16, 2023
0299a29
cli fixes
justusschock Feb 16, 2023
2e66197
fix tests and warnings
justusschock Feb 16, 2023
af09b84
update PL docs
justusschock Feb 16, 2023
747a08c
update examples
justusschock Feb 16, 2023
de2281d
update src code
justusschock Feb 16, 2023
b4e2e01
update tests
justusschock Feb 16, 2023
b553824
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2023
8321309
Merge branch 'master' into 2.0/precision_PL
justusschock Feb 17, 2023
6976db7
update with latest fabric changes
justusschock Feb 17, 2023
a730bcc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
e4eb724
fix tests
justusschock Feb 17, 2023
1ea5221
Merge branch 'master' into 2.0/precision_PL
justusschock Feb 17, 2023
84e24aa
update test to see results
justusschock Feb 17, 2023
3f8c319
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
cbb85ae
update tests
justusschock Feb 17, 2023
03482ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
4d19de4
update utils
justusschock Feb 17, 2023
14b06ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
01ef70e
update
justusschock Feb 17, 2023
69a46e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
4181d60
update
justusschock Feb 17, 2023
9cfbc24
revert pre-commit yet again
justusschock Feb 17, 2023
24 changes: 18 additions & 6 deletions docs/source-pytorch/common/precision_basic.rst
@@ -20,11 +20,11 @@ Higher precision, such as the 64-bit floating-point, can be used for highly sens
16-bit Precision
****************

Use 16-bit precision to cut your memory consumption in half so that you can train and deploy larger models. If your GPUs are [`Tensor Core <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html>`_] GPUs, you can also get a ~3x speed improvement. Half precision can sometimes lead to unstable training.
Use 16-bit mixed precision to lower your memory consumption by up to half so that you can train and deploy larger models. If your GPUs are [`Tensor Core <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html>`_] GPUs, you can also get a ~3x speed improvement. Half precision can sometimes lead to unstable training.

.. code::

Trainer(precision=16)
Trainer(precision='16-mixed')

----

@@ -36,6 +36,12 @@ Use 16-bit precision to cut your memory consumption in half so that you can trai

.. testcode::

Trainer(precision='32-true')

# or
Trainer(precision='32')

# or
Trainer(precision=32)

----
@@ -48,6 +54,12 @@ For certain scientific computations, 64-bit precision enables more accurate mode

.. testcode::

Trainer(precision='64-true')

# or
Trainer(precision='64')

# or
Trainer(precision=64)

.. note::
@@ -70,22 +82,22 @@ Precision support by accelerator
- GPU
- TPU
- IPU
* - 16
* - 16 Mixed
- No
- Yes
- No
- Yes
* - BFloat16
* - BFloat16 Mixed
- Yes
- Yes
- Yes
- No
* - 32
* - 32 True
- Yes
- Yes
- Yes
- Yes
* - 64
* - 64 True
- Yes
- Yes
- No
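Taken together, the renamed flags in this file map onto the Trainer as follows; a minimal sketch, assuming Lightning 2.0 with this PR's changes installed:

```python
from lightning.pytorch import Trainer

# 16-bit mixed precision (AMP) and bfloat16 mixed precision
Trainer(precision="16-mixed", accelerator="gpu", devices=1)
Trainer(precision="bf16-mixed", accelerator="gpu", devices=1)

# full precision keeps the shorthand forms but gains an explicit "-true" spelling
Trainer(precision="32-true")  # same as precision="32" or precision=32
Trainer(precision="64-true")  # same as precision="64" or precision=64
```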
2 changes: 1 addition & 1 deletion docs/source-pytorch/common/precision_expert.rst
@@ -20,7 +20,7 @@ You can also customize and pass your own Precision Plugin by subclassing the :cl
.. code-block:: python

class CustomPrecisionPlugin(PrecisionPlugin):
precision = 16
precision = '16-mixed'

...

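For context, a hedged sketch of what the full subclass from this snippet might look like: the `precision` attribute is from the diff, while the `forward_context` override and the autocast dtype are illustrative assumptions.

```python
from contextlib import contextmanager
from typing import Generator, Literal

import torch
from lightning.pytorch.plugins import PrecisionPlugin


class CustomPrecisionPlugin(PrecisionPlugin):
    # the class attribute now uses the new string form instead of the integer 16
    precision: Literal["16-mixed"] = "16-mixed"

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        # illustrative: run the forward pass under fp16 autocast
        with torch.autocast("cuda", dtype=torch.float16):
            yield
```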
2 changes: 1 addition & 1 deletion docs/source-pytorch/common/precision_intermediate.rst
@@ -63,7 +63,7 @@ Since computation happens in FP16, there is a chance of numerical instability du

.. note::

When using TPUs, setting ``precision=16`` will enable bfloat16, the only supported half precision type on TPUs.
When using TPUs, setting ``precision='16-mixed'`` will enable bfloat16, the only supported half precision type on TPUs.

.. testcode::
:skipif: not torch.cuda.is_available()
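A short sketch of the TPU behaviour described in the note above (assumes a TPU environment with `torch_xla` available):

```python
from lightning.pytorch import Trainer

# On TPUs, "16-mixed" enables bfloat16 under the hood; it is the only
# half-precision type the hardware supports, there is no fp16 path.
trainer = Trainer(accelerator="tpu", devices=8, precision="16-mixed")
```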
4 changes: 2 additions & 2 deletions docs/source-pytorch/common/trainer.rst
@@ -926,10 +926,10 @@ Half precision, or mixed precision, is the combined use of 32 and 16 bit floatin
trainer = Trainer(precision=32)

# 16-bit precision
trainer = Trainer(precision=16, accelerator="gpu", devices=1) # works only on CUDA
trainer = Trainer(precision="16-mixed", accelerator="gpu", devices=1) # works only on CUDA

# bfloat16 precision
trainer = Trainer(precision="bf16")
trainer = Trainer(precision="bf16-mixed")

# 64-bit precision
trainer = Trainer(precision=64)
1 change: 0 additions & 1 deletion docs/source-pytorch/fabric/fundamentals/launch.rst
@@ -74,7 +74,6 @@ This is essentially the same as running ``python path/to/your/script.py``, but i
precision (``16-mixed`` or ``16``) or
bfloat16 precision (``bf16-mixed`` or
``bf16``)

--help Show this message and exit.


2 changes: 1 addition & 1 deletion examples/app_multi_node/train_fabric.py
@@ -15,7 +15,7 @@ def run(self):
)

# 2. Create Fabric.
fabric = Fabric(strategy="ddp", precision=16)
fabric = Fabric(strategy="ddp", precision="16-mixed")
model, optimizer = fabric.setup(model, torch.optim.SGD(model.parameters(), lr=0.01))
criterion = torch.nn.MSELoss()

2 changes: 1 addition & 1 deletion examples/pl_hpu/mnist_sample.py
@@ -63,7 +63,7 @@ def configure_optimizers(self):
"accelerator": "hpu",
"devices": 1,
"max_epochs": 1,
"plugins": lazy_instance(HPUPrecisionPlugin, precision=16),
"plugins": lazy_instance(HPUPrecisionPlugin, precision="16-mixed"),
},
run=False,
save_config_kwargs={"overwrite": True},
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -95,6 +95,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781))


- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16783](https://github.com/Lightning-AI/lightning/pull/16783))

### Deprecated

-
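A minimal before/after sketch of the renamed values for users upgrading (the shorthand full-precision forms noted in the docs above remain accepted):

```python
from lightning.pytorch import Trainer

Trainer(precision="16-mixed")    # was: Trainer(precision=16)
Trainer(precision="bf16-mixed")  # was: Trainer(precision="bf16")
Trainer(precision="32-true")     # was: Trainer(precision=32); "32"/32 still accepted
Trainer(precision="64-true")     # was: Trainer(precision=64); "64"/64 still accepted
```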
15 changes: 9 additions & 6 deletions src/lightning/pytorch/plugins/precision/amp.py
@@ -34,15 +34,18 @@ class MixedPrecisionPlugin(PrecisionPlugin):
"""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> None:
self.precision = cast(Literal["16", "bf16"], str(precision)) # type: ignore
if scaler is None and self.precision == "16":
self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
if scaler is None and self.precision == "16-mixed":
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16":
raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
self.scaler = scaler

@@ -97,7 +100,7 @@ def clip_gradients(
def autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
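A hedged usage sketch of the updated constructor (the import path mirrors the file above; a CUDA device is assumed):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin

# "16-mixed" gets a default torch.cuda.amp.GradScaler when none is passed
plugin = MixedPrecisionPlugin(precision="16-mixed", device="cuda")
trainer = Trainer(accelerator="gpu", devices=1, plugins=[plugin])

# "bf16-mixed" must not be combined with a scaler; this would raise MisconfigurationException:
# MixedPrecisionPlugin("bf16-mixed", "cuda", scaler=torch.cuda.amp.GradScaler())
```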
10 changes: 4 additions & 6 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -31,9 +31,7 @@

warning_cache = WarningCache()

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]


class DeepSpeedPrecisionPlugin(PrecisionPlugin):
@@ -46,14 +44,14 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
If unsupported ``precision`` is provided.
"""

def __init__(self, precision: Literal["32", 32, "16", 16, "bf16"]) -> None:
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
def __init__(self, precision: Literal["32-true", "16-mixed", "bf16-mixed"]) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore
self.precision = cast(_PRECISION_INPUT, str(precision))

def backward( # type: ignore[override]
self,
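A small sketch of the tightened validation (error text per the diff; the old integer form now falls outside the supported literals):

```python
from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

DeepSpeedPrecisionPlugin("16-mixed")    # ok
DeepSpeedPrecisionPlugin("bf16-mixed")  # ok

try:
    DeepSpeedPrecisionPlugin(16)  # no longer a supported value
except ValueError as err:
    print(err)  # `Trainer(strategy='deepspeed', precision=16)` is not supported. ...
```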
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/double.py
@@ -72,7 +72,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
class DoublePrecisionPlugin(PrecisionPlugin):
"""Plugin for training with double (``torch.float64``) precision."""

precision: Literal["64"] = "64" # type: ignore
precision: Literal["64-true"] = "64-true"

def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
8 changes: 4 additions & 4 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -31,12 +31,12 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):
"""AMP for Fully Sharded Data Parallel (FSDP) Training."""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[ShardedGradScaler] = None
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.")
super().__init__(
precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None)
precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16-mixed" else None)
)

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
@@ -52,9 +52,9 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
assert MixedPrecision is not None
if self.precision == "16":
if self.precision == "16-mixed":
dtype = torch.float16
elif self.precision == "bf16":
elif self.precision == "bf16-mixed":
dtype = torch.bfloat16
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
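For reference, a sketch of the dtype selection performed by `mixed_precision_config` above; only the dtype branch is shown in the diff, so the exact `MixedPrecision` fields filled in are an assumption here (PyTorch >= 1.12):

```python
import torch
from torch.distributed.fsdp import MixedPrecision


def fsdp_mixed_precision(precision: str) -> MixedPrecision:
    # mirrors the branch in FSDPMixedPrecisionPlugin.mixed_precision_config
    dtype = torch.float16 if precision == "16-mixed" else torch.bfloat16
    # assumed: apply the same dtype to params, gradient reduction and buffers
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
```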
12 changes: 5 additions & 7 deletions src/lightning/pytorch/plugins/precision/hpu.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast, Literal, Optional, Union
from typing import cast, Literal, Optional

from typing_extensions import get_args

@@ -22,9 +22,7 @@
if _HPU_AVAILABLE:
from habana_frameworks.torch.hpex import hmp

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]


class HPUPrecisionPlugin(PrecisionPlugin):
@@ -48,14 +46,14 @@ def __init__(
) -> None:
if not _HPU_AVAILABLE:
raise MisconfigurationException("HPU precision plugin requires HPU devices.")
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore
if self.precision in ("16", "bf16"):
self.precision = cast(_PRECISION_INPUT, str(precision))
if self.precision in ("16-mixed", "bf16-mixed"):
hmp.convert(
opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose
)
12 changes: 5 additions & 7 deletions src/lightning/pytorch/plugins/precision/ipu.py
@@ -27,27 +27,25 @@

warning_cache = WarningCache()

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed"]


class IPUPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for IPU integration.

Raises:
ValueError:
If the precision is neither 16 nor 32.
If the precision is neither 16-mixed nor 32-true.
"""

def __init__(self, precision: Literal["32", 32, "16", 16]) -> None:
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
def __init__(self, precision: Literal["32-true", "16-mixed"]) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore
self.precision = cast(_PRECISION_INPUT, str(precision))

def backward( # type: ignore[override]
self,
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/tpu_bf16.py
@@ -23,7 +23,7 @@
class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
"""Plugin that enables bfloats on TPUs."""

precision: Literal["bf16"] = "bf16" # type: ignore
precision: Literal["bf16-mixed"] = "bf16-mixed"

def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
12 changes: 6 additions & 6 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -127,8 +127,8 @@ def __init__(

Arguments:

zero_optimization: Enable ZeRO optimization. This is compatible with either `precision=16` or
`precision="bf16"`.
zero_optimization: Enable ZeRO optimization. This is compatible with either `precision="16-mixed"` or
`precision="bf16-mixed"`.

stage: Different stages of the ZeRO Optimizer. 0 is disabled,
1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning,
@@ -505,9 +505,9 @@ def model_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
assert self._config_initialized

if self.precision_plugin.precision == "16":
if self.precision_plugin.precision == "16-mixed":
dtype = torch.float16
elif self.precision_plugin.precision == "bf16":
elif self.precision_plugin.precision == "bf16-mixed":
dtype = torch.bfloat16
else:
dtype = torch.float32
@@ -641,7 +641,7 @@ def _auto_select_batch_size(self) -> int:

def _format_precision_config(self) -> None:
assert isinstance(self.config, dict)
if self.precision_plugin.precision == "16":
if self.precision_plugin.precision == "16-mixed":
if "fp16" not in self.config:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")
@@ -653,7 +653,7 @@ def _format_precision_config(self) -> None:
"hysteresis": self.hysteresis,
"min_loss_scale": self.min_loss_scale,
}
elif "bf16" not in self.config and self.precision_plugin.precision == "bf16":
elif "bf16" not in self.config and self.precision_plugin.precision == "bf16-mixed":
rank_zero_info("Enabling DeepSpeed BF16.")
self.config["bf16"] = {"enabled": True}

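In practice, the branch above means the precision flag drives the generated DeepSpeed config; a hedged sketch (requires `deepspeed` installed and a CUDA device):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

# precision="16-mixed" injects an "fp16" section (DeepSpeed's standalone AMP) into
# the generated config; precision="bf16-mixed" would inject {"bf16": {"enabled": True}} instead
trainer = Trainer(
    strategy=DeepSpeedStrategy(stage=2),
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
)
```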
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -99,8 +99,8 @@ class FSDPStrategy(ParallelStrategy):
algorithms to help backward communication and computation overlapping.
The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
mixed_precision:
Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
or BF16 if ``precision=bf16`` unless a config is passed in.
Mixed Precision config. By default, Lightning will enable FP16 if ``precision="16-mixed"``
or BF16 if ``precision="bf16-mixed"`` unless a config is passed in.
This is only available in PyTorch 1.12 and later.
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention + feed-forward).
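A hedged sketch of the override path mentioned in the docstring: passing an explicit config instead of letting Lightning derive one from `precision` (names from `torch.distributed.fsdp`, PyTorch 1.12+):

```python
import torch
from torch.distributed.fsdp import MixedPrecision
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

# keep parameters in bf16 but reduce gradients in fp32, rather than the defaults
# Lightning would pick for precision="bf16-mixed"
strategy = FSDPStrategy(
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
)
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)
```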
14 changes: 11 additions & 3 deletions src/lightning/pytorch/strategies/utils.py
@@ -32,9 +32,17 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
mod.register_strategies(registry)


def _fp_to_half(tensor: Tensor, precision: Literal["64", 64, "32", 32, "16", 16, "bf16"]) -> Tensor:
if str(precision) == "16":
def _fp_to_half(
tensor: Tensor,
precision: Literal[
"64-true",
"32-true",
"16-mixed",
"bf16-mixed",
],
) -> Tensor:
if str(precision) == "16-mixed":
return _convert_fp_tensor(tensor, torch.half)
if precision == "bf16":
if precision == "bf16-mixed":
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor
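
A quick sketch of the helper's behaviour under the new literals (hedged; `_fp_to_half` is a private utility and the import path follows the file above):

```python
import torch
from lightning.pytorch.strategies.utils import _fp_to_half

t = torch.rand(4)  # float32
assert _fp_to_half(t, "16-mixed").dtype == torch.half
assert _fp_to_half(t, "bf16-mixed").dtype == torch.bfloat16
assert _fp_to_half(t, "32-true").dtype == torch.float32  # passthrough
```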