From cc65718d947f7746970108bdd72a6fc637395eac Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 1 Oct 2022 17:16:01 +0200 Subject: [PATCH] imports --- src/lightning_lite/plugins/__init__.py | 1 + src/lightning_lite/plugins/precision/__init__.py | 1 + src/lightning_lite/plugins/precision/fsdp.py | 4 +++- src/lightning_lite/strategies/__init__.py | 1 + 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/lightning_lite/plugins/__init__.py b/src/lightning_lite/plugins/__init__.py index 54aa3a4e4e113..0d166904491be 100644 --- a/src/lightning_lite/plugins/__init__.py +++ b/src/lightning_lite/plugins/__init__.py @@ -18,6 +18,7 @@ from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.tpu import TPUPrecision diff --git a/src/lightning_lite/plugins/precision/__init__.py b/src/lightning_lite/plugins/precision/__init__.py index 412ef9274822c..c390edd8e36f2 100644 --- a/src/lightning_lite/plugins/precision/__init__.py +++ b/src/lightning_lite/plugins/precision/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision from lightning_lite.plugins.precision.double import DoublePrecision +from lightning_lite.plugins.precision.fsdp import FSDPPrecision from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision from lightning_lite.plugins.precision.precision import Precision from lightning_lite.plugins.precision.tpu import TPUPrecision diff --git a/src/lightning_lite/plugins/precision/fsdp.py b/src/lightning_lite/plugins/precision/fsdp.py index 4ef4bbfe168cf..45f38838774a8 100644 --- a/src/lightning_lite/plugins/precision/fsdp.py +++ b/src/lightning_lite/plugins/precision/fsdp.py @@ -27,7 +27,9 @@ class FSDPPrecision(NativeMixedPrecision): """AMP for Fully Sharded Data Parallel training.""" - def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None) -> None: + def __init__( + self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None + ) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") diff --git a/src/lightning_lite/strategies/__init__.py b/src/lightning_lite/strategies/__init__.py index f9cf74e30e4c0..a8d235708b573 100644 --- a/src/lightning_lite/strategies/__init__.py +++ b/src/lightning_lite/strategies/__init__.py @@ -17,6 +17,7 @@ from lightning_lite.strategies.dp import DataParallelStrategy # noqa: F401 from lightning_lite.strategies.fairscale import DDPShardedStrategy # noqa: F401 from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy # noqa: F401 +from lightning_lite.strategies.fsdp import FSDPStrategy # noqa: F401 from lightning_lite.strategies.parallel import ParallelStrategy # noqa: F401 from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry from lightning_lite.strategies.single_device import SingleDeviceStrategy # noqa: F401