
Commit

imports
awaelchli committed Oct 1, 2022
1 parent 80d24fe · commit cc65718
Showing 4 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/lightning_lite/plugins/__init__.py
@@ -18,6 +18,7 @@
 from lightning_lite.plugins.io.xla_plugin import XLACheckpointIO
 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning_lite.plugins.precision.double import DoublePrecision
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.tpu import TPUPrecision
1 change: 1 addition & 0 deletions src/lightning_lite/plugins/precision/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning_lite.plugins.precision.double import DoublePrecision
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.tpu import TPUPrecision
4 changes: 3 additions & 1 deletion src/lightning_lite/plugins/precision/fsdp.py
@@ -27,7 +27,9 @@
 class FSDPPrecision(NativeMixedPrecision):
     """AMP for Fully Sharded Data Parallel training."""
 
-    def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None) -> None:
+    def __init__(
+        self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
+    ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
             raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")
 
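The change above only wraps the constructor signature across lines; behavior is unchanged. For context, here is a minimal usage sketch based solely on that signature. The device string and the omission of the optional scaler are assumptions about typical use, not part of the diff, and FSDPPrecision requires PyTorch 1.12 or newer, as the RuntimeError check shows.

# Minimal sketch (not from the diff): constructing FSDPPrecision with the
# signature reformatted above. Assumes PyTorch >= 1.12 and a CUDA device
# string; the optional ShardedGradScaler argument is left at its default of None.
from lightning_lite.plugins.precision.fsdp import FSDPPrecision

precision_plugin = FSDPPrecision(precision=16, device="cuda")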
1 change: 1 addition & 0 deletions src/lightning_lite/strategies/__init__.py
@@ -17,6 +17,7 @@
 from lightning_lite.strategies.dp import DataParallelStrategy  # noqa: F401
 from lightning_lite.strategies.fairscale import DDPShardedStrategy  # noqa: F401
 from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy  # noqa: F401
+from lightning_lite.strategies.fsdp import FSDPStrategy  # noqa: F401
 from lightning_lite.strategies.parallel import ParallelStrategy  # noqa: F401
 from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry
 from lightning_lite.strategies.single_device import SingleDeviceStrategy  # noqa: F401
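Taken together, the three __init__.py additions re-export the FSDP classes at the package level. A short sketch of the imports this commit enables follows; constructor arguments for FSDPStrategy are not part of this diff and are therefore not shown.

# Imports made possible by the __init__.py additions in this commit.
from lightning_lite.plugins import FSDPPrecision      # re-exported by plugins/__init__.py
from lightning_lite.strategies import FSDPStrategy    # re-exported by strategies/__init__.py
# FSDPPrecision is also re-exported from lightning_lite.plugins.precision.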
