From f6cad324c0ecb22b4d6d2b0c7e78a9949d9c692b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 16 Feb 2023 15:08:25 +0100 Subject: [PATCH 01/14] Use the local batch_idx to update the progress bar (#16760) --- src/lightning/pytorch/CHANGELOG.md | 4 ++ .../pytorch/callbacks/progress/base.py | 43 -------------- .../callbacks/progress/rich_progress.py | 10 ++-- .../callbacks/progress/tqdm_progress.py | 59 ++++++++++++++----- 4 files changed, 53 insertions(+), 63 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index a85d8d73c816d..988023827b3b9 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -297,6 +297,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the unused `lightning.pytorch.utilities.supporters.{SharedCycleIteratorState,CombinedLoaderIterator}` classes ([#16714](https://github.com/Lightning-AI/lightning/pull/16714)) +- Removed `ProgressBarBase.{train_batch_idx,val_batch_idx,test_batch_idx,predict_batch_idx}` properties ([#16760](https://github.com/Lightning-AI/lightning/pull/16760)) + + + - Removed the `Trainer(track_grad_norm=...)` argument ([#16745](https://github.com/Lightning-AI/lightning/pull/16745)) diff --git a/src/lightning/pytorch/callbacks/progress/base.py b/src/lightning/pytorch/callbacks/progress/base.py index 041783ca68b1b..c0492d5bf314e 100644 --- a/src/lightning/pytorch/callbacks/progress/base.py +++ b/src/lightning/pytorch/callbacks/progress/base.py @@ -76,49 +76,6 @@ def test_description(self) -> str: def predict_description(self) -> str: return "Predicting" - @property - def _val_processed(self) -> int: - # use total in case validation runs more than once per training epoch - return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed - - @property - def train_batch_idx(self) -> int: - """The number of batches processed during training. - - Use this to update your progress bar. - """ - return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed - - @property - def val_batch_idx(self) -> int: - """The number of batches processed during validation. - - Use this to update your progress bar. - """ - if self.trainer.state.fn == "fit": - loop = self.trainer.fit_loop.epoch_loop.val_loop - else: - loop = self.trainer.validate_loop - - current_batch_idx = loop.epoch_loop.batch_progress.current.processed - return current_batch_idx - - @property - def test_batch_idx(self) -> int: - """The number of batches processed during testing. - - Use this to update your progress bar. - """ - return self.trainer.test_loop.epoch_loop.batch_progress.current.processed - - @property - def predict_batch_idx(self) -> int: - """The number of batches processed during prediction. - - Use this to update your progress bar. - """ - return self.trainer.predict_loop.epoch_loop.batch_progress.current.processed - @property def total_train_batches(self) -> Union[int, float]: """The total number of training batches, which may change from epoch to epoch. 
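A minimal sketch of what this change means for user code (illustrative, not part of the patch; the `VerboseBar` name is made up): a custom progress bar now reads the batch index from the hook arguments instead of the removed `train_batch_idx`/`val_batch_idx`/`test_batch_idx`/`predict_batch_idx` properties, which is exactly the `batch_idx + 1` substitution made in the rich and tqdm diffs below.

```python
from lightning.pytorch.callbacks import TQDMProgressBar


class VerboseBar(TQDMProgressBar):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        # the hook's local batch_idx replaces the removed trainer-derived property
        print(f"finished train batch {batch_idx + 1}/{self.total_train_batches}")
```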
diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 157350b63288c..a8f77a3a91dfc 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -485,7 +485,7 @@ def on_predict_batch_start( def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: - self._update(self.train_progress_bar_id, self.train_batch_idx) + self._update(self.train_progress_bar_id, batch_idx + 1) self._update_metrics(trainer, pl_module) self.refresh() @@ -504,9 +504,9 @@ def on_validation_batch_end( if self.is_disabled: return if trainer.sanity_checking: - self._update(self.val_sanity_progress_bar_id, self.val_batch_idx) + self._update(self.val_sanity_progress_bar_id, batch_idx + 1) elif self.val_progress_bar_id is not None: - self._update(self.val_progress_bar_id, self.val_batch_idx) + self._update(self.val_progress_bar_id, batch_idx + 1) self.refresh() def on_test_batch_end( @@ -521,7 +521,7 @@ def on_test_batch_end( if self.is_disabled: return assert self.test_progress_bar_id is not None - self._update(self.test_progress_bar_id, self.test_batch_idx) + self._update(self.test_progress_bar_id, batch_idx + 1) self.refresh() def on_predict_batch_end( @@ -536,7 +536,7 @@ def on_predict_batch_end( if self.is_disabled: return assert self.predict_progress_bar_id is not None - self._update(self.predict_progress_bar_id, self.predict_batch_idx) + self._update(self.predict_progress_bar_id, batch_idx + 1) self.refresh() def _get_train_description(self, current_epoch: int) -> str: diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 121dccd0327bf..fe57b79c4fd02 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -17,6 +17,8 @@ import sys from typing import Any, Dict, Optional, Union +from lightning.pytorch.utilities.types import STEP_OUTPUT + # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed @@ -190,7 +192,6 @@ def init_train_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for training.""" bar = Tqdm( desc=self.train_description, - initial=self.train_batch_idx, position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -204,7 +205,6 @@ def init_predict_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for predicting.""" bar = Tqdm( desc=self.predict_description, - initial=self.train_batch_idx, position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -256,10 +256,12 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") - def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: - current = self.train_batch_idx - if self._should_update(current, self.train_progress_bar.total): - _update_n(self.train_progress_bar, current) + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + n = batch_idx + 1 + if self._should_update(n, self.train_progress_bar.total): + _update_n(self.train_progress_bar, n) 
self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -289,9 +291,18 @@ def on_validation_batch_start( desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") - def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: - if self._should_update(self.val_batch_idx, self.val_progress_bar.total): - _update_n(self.val_progress_bar, self.val_batch_idx) + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + n = batch_idx + 1 + if self._should_update(n, self.val_progress_bar.total): + _update_n(self.val_progress_bar, n) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._train_progress_bar is not None and trainer.state.fn == "fit": @@ -317,9 +328,18 @@ def on_test_batch_start( self.test_progress_bar.initial = 0 self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") - def on_test_batch_end(self, *_: Any) -> None: - if self._should_update(self.test_batch_idx, self.test_progress_bar.total): - _update_n(self.test_progress_bar, self.test_batch_idx) + def on_test_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + n = batch_idx + 1 + if self._should_update(n, self.test_progress_bar.total): + _update_n(self.test_progress_bar, n) def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar.close() @@ -343,9 +363,18 @@ def on_predict_batch_start( self.predict_progress_bar.initial = 0 self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") - def on_predict_batch_end(self, *_: Any) -> None: - if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total): - _update_n(self.predict_progress_bar, self.predict_batch_idx) + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + n = batch_idx + 1 + if self._should_update(n, self.predict_progress_bar.total): + _update_n(self.predict_progress_bar, n) def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar.close() From 57f2f1c0b407a5d37a52e8296759b700a3fad67d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 16 Feb 2023 15:08:51 +0100 Subject: [PATCH 02/14] Fix RunningStage properties for sanity checking (#16774) --- src/lightning/pytorch/trainer/states.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/trainer/states.py b/src/lightning/pytorch/trainer/states.py index 336a6c08b2778..73b7cb71dcf82 100644 --- a/src/lightning/pytorch/trainer/states.py +++ b/src/lightning/pytorch/trainer/states.py @@ -63,13 +63,11 @@ class RunningStage(LightningEnum): @property def evaluating(self) -> bool: - return self in (self.VALIDATING, self.TESTING) + return self in (self.VALIDATING, self.TESTING, self.SANITY_CHECKING) @property def dataloader_prefix(self) -> Optional[str]: - if self == 
self.SANITY_CHECKING: - return None - if self == self.VALIDATING: + if self in (self.VALIDATING, self.SANITY_CHECKING): return "val" return self.value From c9452df005ae176e640e40cfa600e5863b34b714 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 16 Feb 2023 15:09:23 +0100 Subject: [PATCH 03/14] Run XLA's dataloader validation per dataloader (#16775) --- src/lightning/fabric/strategies/xla.py | 18 ++++---- src/lightning/pytorch/strategies/strategy.py | 5 +-- src/lightning/pytorch/strategies/tpu_spawn.py | 42 ++++--------------- .../strategies/test_tpu_spawn.py | 36 ++-------------- 4 files changed, 21 insertions(+), 80 deletions(-) diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index ffaec51891412..a04fc0c1055f8 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import io import os -from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union import torch from torch import Tensor @@ -30,7 +30,6 @@ from lightning.fabric.strategies import ParallelStrategy from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.strategies.strategy import TBroadcast -from lightning.fabric.utilities.apply_func import apply_to_collection from lightning.fabric.utilities.data import has_len from lightning.fabric.utilities.rank_zero import rank_zero_only from lightning.fabric.utilities.types import _PATH, ReduceOp @@ -222,12 +221,9 @@ def _set_world_ranks(self) -> None: rank_zero_only.rank = self.cluster_environment.global_rank() @staticmethod - def _validate_dataloader(dataloaders: DataLoader) -> None: - def check_has_len(dataloader: DataLoader) -> None: - if not has_len(dataloader): - raise TypeError( - "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." - " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." - ) - - apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len) + def _validate_dataloader(dataloader: object) -> None: + if not has_len(dataloader): + raise TypeError( + "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." + " HINT: You can mock the length on your dataset to bypass this error." 
+ ) diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index bb1ce96d62ebe..b3a2feb030343 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -14,13 +14,12 @@ import contextlib import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union import torch from torch import Tensor from torch.nn import Module from torch.optim import Optimizer -from torch.utils.data import DataLoader import lightning.pytorch as pl from lightning.fabric.plugins import CheckpointIO @@ -405,7 +404,7 @@ def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: return output - def process_dataloader(self, dataloader: DataLoader) -> DataLoader: + def process_dataloader(self, dataloader: Iterable) -> Iterable: """Wraps the dataloader if necessary. Args: diff --git a/src/lightning/pytorch/strategies/tpu_spawn.py b/src/lightning/pytorch/strategies/tpu_spawn.py index bcb5d669b7186..757bfa3c99fd3 100644 --- a/src/lightning/pytorch/strategies/tpu_spawn.py +++ b/src/lightning/pytorch/strategies/tpu_spawn.py @@ -13,13 +13,11 @@ # limitations under the License. import io import os -from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union import torch -from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch.nn import Module -from torch.utils.data import DataLoader import lightning.pytorch as pl from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE @@ -34,12 +32,10 @@ from lightning.pytorch.strategies.ddp_spawn import DDPSpawnStrategy from lightning.pytorch.strategies.launchers.xla import _XLALauncher from lightning.pytorch.strategies.strategy import TBroadcast -from lightning.pytorch.trainer.connectors.data_connector import DataConnector from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only -from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import STEP_OUTPUT if TYPE_CHECKING and _XLA_AVAILABLE: from torch_xla.distributed.parallel_loader import MpDeviceLoader @@ -98,34 +94,14 @@ def root_device(self) -> torch.device: return xm.xla_device() @staticmethod - def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None: - def check_has_len(dataloader: DataLoader) -> None: - if not has_len(dataloader): - raise MisconfigurationException( - "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." - " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." 
- ) - - apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len) - - @staticmethod - def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: - """Validate and fail fast if the dataloaders were passed directly to fit.""" - connector: DataConnector = model.trainer._data_connector - sources = ( - connector._train_dataloader_source, - connector._val_dataloader_source, - connector._test_dataloader_source, - connector._predict_dataloader_source, - ) - for source in sources: - if not source.is_module(): - assert source.instance is not None - assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule)) - TPUSpawnStrategy._validate_dataloader(source.instance) + def _validate_dataloader(dataloader: object) -> None: + if not has_len(dataloader): + raise TypeError( + "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." + " HINT: You can mock the length on your dataset to bypass this error." + ) def connect(self, model: "pl.LightningModule") -> None: - TPUSpawnStrategy._validate_patched_dataloaders(model) import torch_xla.distributed.xla_multiprocessing as xmp self.wrapped_model = xmp.MpModelWrapper(_LightningModuleWrapperBase(model)) @@ -166,7 +142,7 @@ def is_distributed(self) -> bool: return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1 - def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader": + def process_dataloader(self, dataloader: Iterable) -> "MpDeviceLoader": TPUSpawnStrategy._validate_dataloader(dataloader) from torch_xla.distributed.parallel_loader import MpDeviceLoader diff --git a/tests/tests_pytorch/strategies/test_tpu_spawn.py b/tests/tests_pytorch/strategies/test_tpu_spawn.py index 73f14fa318813..ddc3c2c76b577 100644 --- a/tests/tests_pytorch/strategies/test_tpu_spawn.py +++ b/tests/tests_pytorch/strategies/test_tpu_spawn.py @@ -22,7 +22,6 @@ from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.strategies import TPUSpawnStrategy -from lightning.pytorch.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf @@ -45,39 +44,10 @@ def predict_dataloader(self): _loader_no_len = CustomNotImplementedErrorDataloader(_loader) -@pytest.mark.parametrize( - "train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders", - [ - (_loader_no_len, None, None, None), - (None, _loader_no_len, None, None), - (None, None, _loader_no_len, None), - (None, None, None, _loader_no_len), - (None, [_loader, _loader_no_len], None, None), - ], -) -def test_error_iterable_dataloaders_passed_to_fit( - xla_available, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders -): - """Test that the TPUSpawnStrategy identifies dataloaders with iterable datasets and fails early.""" - trainer = Trainer() - model = BoringModelNoDataloaders() - model.trainer = trainer - - trainer._data_connector.attach_dataloaders( - model, - train_dataloaders=train_dataloaders, - val_dataloaders=val_dataloaders, - test_dataloaders=test_dataloaders, - predict_dataloaders=predict_dataloaders, - ) - - with pytest.raises(MisconfigurationException, match="TPUs do not currently support"): - TPUSpawnStrategy(MagicMock()).connect(model) - - def test_error_process_iterable_dataloader(xla_available): - with pytest.raises(MisconfigurationException, 
match="TPUs do not currently support"): - TPUSpawnStrategy(MagicMock()).process_dataloader(_loader_no_len) + strategy = TPUSpawnStrategy(MagicMock()) + with pytest.raises(TypeError, match="TPUs do not currently support"): + strategy.process_dataloader(_loader_no_len) class BoringModelTPU(BoringModel): From 51d44f57dd8de1822faa81a12eb7ca631ddfe504 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 16 Feb 2023 15:10:16 +0100 Subject: [PATCH 04/14] Prefetch if it's not a sized iterable (#16776) --- src/lightning/pytorch/loops/dataloader/evaluation_loop.py | 8 +------- src/lightning/pytorch/loops/fetchers.py | 8 ++++++-- src/lightning/pytorch/loops/fit_loop.py | 7 +------ src/lightning/pytorch/loops/utilities.py | 4 ++-- tests/tests_pytorch/loops/test_fetchers.py | 5 +++-- 5 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py index 24914733e3ec7..3e30da9adaeb1 100644 --- a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py +++ b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py @@ -74,12 +74,6 @@ def dataloaders(self) -> Sequence[DataLoader]: return [] return dataloaders - @property - def prefetch_batches(self) -> int: - batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches - is_unsized = batches[self.current_dataloader_idx] == float("inf") - return int(is_unsized) - @property def done(self) -> bool: """Returns whether all dataloaders are processed or evaluation should be skipped altogether.""" @@ -126,7 +120,7 @@ def reset(self) -> None: def on_run_start(self) -> None: """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start`` hooks.""" - self._data_fetcher = _select_data_fetcher(self.trainer, prefetch_batches=self.prefetch_batches) + self._data_fetcher = _select_data_fetcher(self.trainer) # hook self._on_evaluation_model_eval() diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 8d983b6e91f04..b7385eb6b5f1b 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -81,7 +81,7 @@ class _PrefetchDataFetcher(_DataFetcher): Args: prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track - whether a batch is the last one (available with :attr:`self.done`) under any training setup. + whether a batch is the last one (available with :attr:`self.done`) when the length is not available. 
""" def __init__(self, prefetch_batches: int = 1) -> None: @@ -98,6 +98,10 @@ def setup(self, dataloader: Iterable) -> None: def __iter__(self) -> "_PrefetchDataFetcher": super().__iter__() + if self._has_len: + # ignore pre-fetching, it's not necessary + return self + # prefetch batches to know when the iterator will be exhausted in advance iterator = self.dataloader_iter assert iterator is not None for _ in range(self.prefetch_batches): @@ -143,7 +147,7 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: finally: self._stop_profiler() self.fetched += 1 - if not self.prefetch_batches and self._has_len: + if self._has_len: # when we don't prefetch but the dataloader is sized, we use the length for `done` dataloader = self.dataloader assert isinstance(dataloader, Sized) # `_has_len` is True diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 73fc7a37175e5..ef9f44c13f465 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -121,11 +121,6 @@ def restarting(self, restarting: bool) -> None: restarting = restarting and epoch_unfinished or self._iteration_based_training() _Loop.restarting.fset(self, restarting) # call the parent setter - @property - def prefetch_batches(self) -> int: - is_unsized = self.trainer.num_training_batches == float("inf") - return int(is_unsized) - @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" @@ -219,7 +214,7 @@ def on_run_start(self) -> None: if self.epoch_loop._should_check_val_epoch(): self.epoch_loop.val_loop._reload_evaluation_dataloaders() - self._data_fetcher = _select_data_fetcher(trainer, self.prefetch_batches) + self._data_fetcher = _select_data_fetcher(trainer) self._is_fresh_start_epoch = True self._results.to(device=trainer.lightning_module.device) diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index d61afe4b98d58..b6f61b75036b6 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -136,7 +136,7 @@ def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None: sampler.set_epoch(epoch) -def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher: +def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher: lightning_module = trainer.lightning_module if trainer.testing: step_fx_name = "test_step" @@ -153,7 +153,7 @@ def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _D "this signature is experimental and the behavior is subject to change." 
) return _DataLoaderIterDataFetcher() - return _PrefetchDataFetcher(prefetch_batches=prefetch_batches) + return _PrefetchDataFetcher() def _no_grad_context(loop_run: Callable) -> Callable: diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index a5c1462f7f4aa..07c6c1507c8d3 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -64,8 +64,9 @@ def generate(): # we can only know the last batch with sized iterables or when we prefetch is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset] - fetched = list(range(prefetch_batches + 1, 4)) - fetched += [3] * (3 - len(fetched)) + fetched = ( + [1, 2, 3] if dataset_cls is SizedDataset else [1, 2, 3, 3, 3, 3, 3][prefetch_batches : prefetch_batches + 3] + ) batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3] expected = list(zip(fetched, batches, is_last_batch)) assert len(expected) == 3 From cc22ddc716820fe4fcc297aefbee644acd0285e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 16 Feb 2023 15:25:56 +0100 Subject: [PATCH 05/14] Remove duplicate no_grad context managers (#16773) --- src/lightning/pytorch/loops/epoch/training_epoch_loop.py | 5 +---- src/lightning/pytorch/trainer/trainer.py | 3 +-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index cf7f38707ca46..1f52128f06954 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -15,8 +15,6 @@ from collections import OrderedDict from typing import Any, Dict, Optional, Union -import torch - import lightning.pytorch as pl from lightning.pytorch import loops # import as loops to avoid circular imports from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher @@ -284,8 +282,7 @@ def _run_validation(self) -> None: # reload dataloaders self.val_loop._reload_evaluation_dataloaders() - with torch.no_grad(): - self.val_loop.run() + self.val_loop.run() def _accumulated_batches_reached(self) -> bool: """Determine if accumulation will be finished by the end of the current batch.""" diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 2926982ecc94c..35beb7c57d662 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -956,8 +956,7 @@ def _run_sanity_check(self) -> None: ] # run eval step - with torch.no_grad(): - val_loop.run() + val_loop.run() call._call_callback_hooks(self, "on_sanity_check_end") From ad698f049bce0ac935c5f56f354bc20ec4f76e70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 16 Feb 2023 17:14:24 +0100 Subject: [PATCH 06/14] Update Colossal AI docs and integration (#16778) --- dockers/base-cuda/Dockerfile | 9 -- dockers/nvidia/Dockerfile | 2 - .../advanced/model_parallel.rst | 128 +----------------- .../advanced/third_party/colossalai.rst | 92 +++++++++++++ docs/source-pytorch/extensions/strategy.rst | 27 +++- requirements/pytorch/strategies.txt | 1 + src/lightning/pytorch/CHANGELOG.md | 3 + 7 files changed, 125 insertions(+), 137 deletions(-) create mode 100644 docs/source-pytorch/advanced/third_party/colossalai.rst diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index d2bd534e33776..424b82ce532ce 100644 --- 
a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -98,18 +98,9 @@ RUN \ pip install -r requirements/pytorch/base.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \ rm assistant.py -RUN \ - # install ColossalAI - # TODO: 1.13 wheels are not released, remove skip once they are - if [[ $PYTORCH_VERSION != "1.13" ]]; then \ - pip install "colossalai==0.2.4"; \ - python -c "import colossalai; print(colossalai.__version__)" ; \ - fi RUN \ # install rest of strategies - # remove colossalai from requirements since they are installed separately - python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \ cat requirements/pytorch/strategies.txt && \ pip install -r requirements/pytorch/devel.txt -r requirements/pytorch/strategies.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html diff --git a/dockers/nvidia/Dockerfile b/dockers/nvidia/Dockerfile index 9bb97e92af04e..cb76595f3eac7 100644 --- a/dockers/nvidia/Dockerfile +++ b/dockers/nvidia/Dockerfile @@ -43,8 +43,6 @@ RUN \ # Installations \ pip install "Pillow>=8.2, !=8.3.0" "cryptography>=3.4" "py>=1.10" --no-cache-dir && \ - # remove colossalai from requirements since they are installed separately - python -c "fname = 'lightning/requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \ PACKAGE_NAME=pytorch pip install './lightning[extra,loggers,strategies]' --no-cache-dir && \ rm -rf lightning && \ pip list diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index 9b3030f02ec8c..6603eae0da6c9 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -37,7 +37,7 @@ This means we cannot sacrifice throughput as much as if we were fine-tuning, bec Overall: * When **fine-tuning** a model, use advanced memory efficient strategies such as :ref:`fully-sharded-training`, :ref:`deepspeed-zero-stage-3` or :ref:`deepspeed-zero-stage-3-offload`, allowing you to fine-tune larger models if you are limited on compute -* When **pre-training** a model, use simpler optimizations such :ref:`sharded-training` or :ref:`deepspeed-zero-stage-2`, scaling the number of GPUs to reach larger parameter sizes +* When **pre-training** a model, use simpler optimizations such as :ref:`deepspeed-zero-stage-2`, scaling the number of GPUs to reach larger parameter sizes * For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` as the throughput degradation is not significant For example when using 128 GPUs, you can **pre-train** large 10 to 20 Billion parameter models using :ref:`deepspeed-zero-stage-2` without having to take a performance hit with more advanced optimized multi-gpu strategy. @@ -52,133 +52,17 @@ Sharding techniques help when model sizes are fairly large; roughly 500M+ parame * When your model is small (ResNet50 of around 80M Parameters), unless you are using unusually large batch sizes or inputs. * Due to high distributed communication between devices, if running on a slow network/interconnect, the training might be much slower than expected and then it's up to you to determince the tradeoff here. ----------- - -.. 
_colossalai: - -*********** -Colossal-AI -*********** - -:class:`~pytorch_lightning.strategies.colossalai.ColossalAIStrategy` implements ZeRO-DP with chunk-based memory management. -With this chunk mechanism, really large models can be trained with a small number of GPUs. -It supports larger trainable model size and batch size than usual heterogeneous training by reducing CUDA memory fragments and CPU memory consumption. -Also, it speeds up this kind of heterogeneous training by fully utilizing all kinds of resources. - -When enabling chunk mechanism, a set of consecutive parameters are stored in a chunk, and then the chunk is sharded across different processes. -This can reduce communication and data transmission frequency and fully utilize communication and PCI-E bandwidth, which makes training faster. - -Unlike traditional implementations, which adopt static memory partition, we implemented a dynamic heterogeneous memory management system named Gemini. -During the first training step, the warmup phase will sample the maximum non-model data memory (memory usage expect parameters, gradients, and optimizer states). -In later training, it will use the collected memory usage information to evict chunks dynamically. -Gemini allows you to fit much larger models with limited GPU memory. - -According to our benchmark results, we can train models with up to 24 billion parameters in 1 GPU. -You can install colossalai by consulting `how to download colossalai `_. -Then, run this benchmark in `Colossalai-PL/gpt `_. - -Here is an example showing how to use ColossalAI: - -.. code-block:: python - - from colossalai.nn.optimizer import HybridAdam - - - class MyBert(LightningModule): - ... - - def configure_sharded_model(self) -> None: - # create your model here - self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased") - - def configure_optimizers(self): - # use the specified optimizer - optimizer = HybridAdam(self.model.parameters(), self.lr) - - ... - - - model = MyBert() - trainer = Trainer(accelerator="gpu", devices=1, precision=16, strategy="colossalai") - trainer.fit(model) - -You can find more examples in the `Colossalai-PL `_ repository. - -.. note:: - - * The only accelerator which ColossalAI supports is ``"gpu"``. But CPU resources will be used when the placement policy is set to "auto" or "cpu". - * The only precision which ColossalAI allows is 16 (FP16). +Cutting-edge and Experimental Strategies +======================================== - * It only supports a single optimizer, which must be ``colossalai.nn.optimizer.CPUAdam`` or ``colossalai.nn.optimizer. - HybridAdam`` now. You can set ``adamw_mode`` to False to use normal Adam. Noticing that ``HybridAdam`` is highly optimized, it uses fused CUDA kernel and parallel CPU kernel. - It is recomended to use ``HybridAdam``, since it updates parameters in GPU and CPU both. +Cutting-edge Lightning strategies are being developed by third-parties outside of Lightning. +If you want to be the first to try the latest and greatest experimental features for model-parallel training, check out the :doc:`Colossal-AI Strategy <./third_party/colossalai>` integration. - * Your model must be created using the :meth:`~pytorch_lightning.core.module.LightningModule.configure_sharded_model` method. - - * ``ColossalaiStrategy`` doesn't support gradient accumulation as of now. - -.. 
_colossal_placement_policy: - -Placement Policy -================ - -Placement policies can help users fully exploit their GPU-CPU heterogeneous memory space for better training efficiency. -There are three options for the placement policy. -They are "cpu", "cuda" and "auto" respectively. - -When the placement policy is set to "cpu", all participated parameters will be offloaded into CPU memory immediately at the end of every auto-grad operation. -In this way, "cpu" placement policy uses the least CUDA memory. -It is the best choice for users who want to exceptionally enlarge their model size or training batch size. - -When using "cuda" option, all parameters are placed in the CUDA memory, no CPU resources will be used during the training. -It is for users who get plenty of CUDA memory. - -The third option, "auto", enables Gemini. -It monitors the consumption of CUDA memory during the warmup phase and collects CUDA memory usage of all auto-grad operations. -In later training steps, Gemini automatically manages the data transmission between GPU and CPU according to collected CUDA memory usage information. -It is the fastest option when CUDA memory is enough. - -Here's an example of changing the placement policy to "cpu". - -.. code-block:: python - - from pytorch_lightning.strategies import ColossalAIStrategy - - model = MyModel() - my_strategy = ColossalAIStrategy(placement_policy="cpu") - trainer = Trainer(accelerator="gpu", devices=4, precision=16, strategy=my_strategy) - trainer.fit(model) - -.. _sharded-training: - -**************** -Sharded Training -**************** - -The technique can be found within `DeepSpeed ZeRO `_ and -`ZeRO-2 `_, -however the implementation is built from the ground up to be PyTorch compatible and standalone. -Sharded Training allows you to maintain GPU scaling efficiency, whilst reducing memory overhead drastically. In short, expect near-normal linear scaling (if your network allows), and significantly reduced memory usage when training large models. - -Sharded Training still utilizes Data Parallel Training under the hood, except optimizer states and gradients are sharded across GPUs. -This means the memory overhead per GPU is lower, as each GPU only has to maintain a partition of your optimizer state and gradients. - -The benefits vary by model and parameter sizes, but we've recorded up to a 63% memory reduction per GPU allowing us to double our model sizes. Because of efficient communication, -these benefits in multi-GPU setups are almost free and throughput scales well with multi-node setups. - -It is highly recommended to use Sharded Training in multi-GPU environments where memory is limited, or where training larger models are beneficial (500M+ parameter models). -A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful. - -.. code-block:: python - - # train using Sharded DDP - trainer = Trainer(strategy="ddp_sharded") - -Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required. ---- + .. 
_fully-sharded-training: ********************** diff --git a/docs/source-pytorch/advanced/third_party/colossalai.rst b/docs/source-pytorch/advanced/third_party/colossalai.rst new file mode 100644 index 0000000000000..5223bdc0ad60d --- /dev/null +++ b/docs/source-pytorch/advanced/third_party/colossalai.rst @@ -0,0 +1,92 @@ +:orphan: + +########### +Colossal-AI +########### + + +The Colossal-AI strategy implements ZeRO-DP with chunk-based memory management. +With this chunk mechanism, really large models can be trained with a small number of GPUs. +It supports larger trainable model size and batch size than usual heterogeneous training by reducing CUDA memory fragments and CPU memory consumption. +Also, it speeds up this kind of heterogeneous training by fully utilizing all kinds of resources. + +When enabling chunk mechanism, a set of consecutive parameters are stored in a chunk, and then the chunk is sharded across different processes. +This can reduce communication and data transmission frequency and fully utilize communication and PCI-E bandwidth, which makes training faster. + +Unlike traditional implementations, which adopt static memory partition, we implemented a dynamic heterogeneous memory management system named Gemini. +During the first training step, the warmup phase will sample the maximum non-model data memory (memory usage expect parameters, gradients, and optimizer states). +In later training, it will use the collected memory usage information to evict chunks dynamically. +Gemini allows you to fit much larger models with limited GPU memory. + +According to our benchmark results, we can train models with up to 24 billion parameters in 1 GPU. + +You can install the Colossal-AI integration by running + +.. code-block:: bash + + pip install lightning-colossalai + +This will install both the `colossalai `_ package as well as the ``ColossalAIStrategy`` for the Lightning Trainer: + +.. code-block:: python + + trainer = Trainer(strategy="colossalai", precision=16, devices=...) + + +You can tune several settings by instantiating the strategy objects and pass options in: + +.. code-block:: python + + from lightning_colossalai import ColossalAIStrategy + + strategy = ColossalAIStrategy(...) + trainer = Trainer(strategy=strategy, precision=16, devices=...) + + +See a full example of a benchmark with the a `GPT-2 model `_ of up to 24 billion parameters + +.. note:: + + * The only accelerator which ColossalAI supports is ``"gpu"``. But CPU resources will be used when the placement policy is set to "auto" or "cpu". + + * The only precision which ColossalAI allows is 16-bit mixed precision (FP16). + + * It only supports a single optimizer, which must be ``colossalai.nn.optimizer.CPUAdam`` or ``colossalai.nn.optimizer. + HybridAdam`` now. You can set ``adamw_mode`` to False to use normal Adam. Noticing that ``HybridAdam`` is highly optimized, it uses fused CUDA kernel and parallel CPU kernel. + It is recomended to use ``HybridAdam``, since it updates parameters in GPU and CPU both. + + * Your model must be created using the :meth:`~pytorch_lightning.core.module.LightningModule.configure_sharded_model` method. + + * ``ColossalaiStrategy`` doesn't support gradient accumulation as of now. + +.. _colossal_placement_policy: + +Placement Policy +================ + +Placement policies can help users fully exploit their GPU-CPU heterogeneous memory space for better training efficiency. +There are three options for the placement policy. +They are "cpu", "cuda" and "auto" respectively. 
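For completeness, a sketch of a module wired up according to the usage notes above (adapted from the example previously in `model_parallel.rst`; the model choice, learning rate, and `training_step` body are illustrative assumptions, not part of this patch):

```python
from colossalai.nn.optimizer import HybridAdam
from lightning.pytorch import LightningModule, Trainer
from transformers import BertForSequenceClassification


class MyBert(LightningModule):
    def __init__(self, lr: float = 1e-5):
        super().__init__()
        self.lr = lr

    def configure_sharded_model(self) -> None:
        # create the model here so the strategy can manage its placement
        self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

    def training_step(self, batch, batch_idx):
        return self.model(**batch).loss

    def configure_optimizers(self):
        # the strategy only supports CPUAdam or the (faster) HybridAdam
        return HybridAdam(self.model.parameters(), self.lr)


model = MyBert()
trainer = Trainer(accelerator="gpu", devices=1, precision=16, strategy="colossalai")
trainer.fit(model)
```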
+ +When the placement policy is set to "cpu", all participated parameters will be offloaded into CPU memory immediately at the end of every auto-grad operation. +In this way, "cpu" placement policy uses the least CUDA memory. +It is the best choice for users who want to exceptionally enlarge their model size or training batch size. + +When using "cuda" option, all parameters are placed in the CUDA memory, no CPU resources will be used during the training. +It is for users who get plenty of CUDA memory. + +The third option, "auto", enables Gemini. +It monitors the consumption of CUDA memory during the warmup phase and collects CUDA memory usage of all auto-grad operations. +In later training steps, Gemini automatically manages the data transmission between GPU and CPU according to collected CUDA memory usage information. +It is the fastest option when CUDA memory is enough. + +Here's an example of changing the placement policy to "cpu". + +.. code-block:: python + + from lightning_colossalai import ColossalAIStrategy + + model = MyModel() + my_strategy = ColossalAIStrategy(placement_policy="cpu") + trainer = Trainer(accelerator="gpu", devices=4, precision=16, strategy=my_strategy) + trainer.fit(model) diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst index 429131ef03944..034d508474745 100644 --- a/docs/source-pytorch/extensions/strategy.rst +++ b/docs/source-pytorch/extensions/strategy.rst @@ -23,7 +23,7 @@ plugin and other optional plugins such as the :ref:`ClusterEnvironment `_ itself). ----------- +---- ***************************** Selecting a Built-in Strategy @@ -69,9 +69,6 @@ The below table lists all relevant strategies available in Lightning with their * - Name - Class - Description - * - colossalai - - :class:`~pytorch_lightning.strategies.ColossalAIStrategy` - - Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. `__ * - fsdp - :class:`~pytorch_lightning.strategies.FSDPStrategy` - Strategy for Fully Sharded Data Parallel training. :ref:`Learn more. ` @@ -102,6 +99,28 @@ The below table lists all relevant strategies available in Lightning with their ---- + +********************** +Third-party Strategies +********************** + +There are powerful third-party strategies that integrate well with Lightning but aren't maintained as part of the ``lightning`` package. + +.. list-table:: List of third-party strategy implementations + :widths: 20 20 20 + :header-rows: 1 + + * - Name + - Package + - Description + * - colossalai + - `Lightning-AI/lightning-colossalai `_ + - Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. 
`__ + + +---- + + ************************ Create a Custom Strategy ************************ diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt index c8a5c9531fe3d..4db2eb301121b 100644 --- a/requirements/pytorch/strategies.txt +++ b/requirements/pytorch/strategies.txt @@ -2,3 +2,4 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment deepspeed>=0.6.0, <0.8.0 # TODO: Include 0.8.x after https://github.com/microsoft/DeepSpeed/commit/b587c7e85470329ac25df7c7c2521ff9b2833db7 gets released +lightning-colossalai==0.1.0dev diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 988023827b3b9..420cb9213dc87 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -310,6 +310,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the `QuantizationAwareTraining` callback ([#16750](https://github.com/Lightning-AI/lightning/pull/16750)) +- Removed the `ColossalAIStrategy` and `ColossalAIPrecisionPlugin` in favor of the new [lightning-colossalai](https://github.com/Lightning-AI/lightning-colossalai) package ([#16757](https://github.com/Lightning-AI/lightning/pull/16757), [#16778](https://github.com/Lightning-AI/lightning/pull/16778)) + + ### Fixed - Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693)) From 746c734e6a7c705d82bd9ecc91bef591d4ace358 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 16 Feb 2023 23:07:46 +0100 Subject: [PATCH 07/14] `SequentialMode` and `dataloader_iter` improvements (#16784) --- src/lightning/pytorch/CHANGELOG.md | 2 +- .../loops/epoch/evaluation_epoch_loop.py | 11 +-- .../loops/epoch/training_epoch_loop.py | 7 +- src/lightning/pytorch/loops/fetchers.py | 33 ++++++--- src/lightning/pytorch/trainer/supporters.py | 68 +++++++++++++------ .../tests_pytorch/trainer/test_supporters.py | 32 +++++++-- 6 files changed, 106 insertions(+), 47 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 420cb9213dc87..5250071080dad 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646)) -- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743)) +- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743), [#16784](https://github.com/Lightning-AI/lightning/pull/16784)) ### Changed diff --git a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py index 1e0c35ca29761..007068febf2ff 100644 --- a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py @@ -114,11 +114,12 @@ def advance( Raises: StopIteration: If the current batch is None """ - if not isinstance(data_fetcher, _DataLoaderIterDataFetcher): - batch_idx = self.batch_progress.current.ready - batch = next(data_fetcher) - else: - batch_idx, batch = next(data_fetcher) + batch_idx = ( + data_fetcher.fetched + if isinstance(data_fetcher, _DataLoaderIterDataFetcher) + else self.batch_progress.current.ready + ) + batch = next(data_fetcher) self.batch_progress.is_last_batch = data_fetcher.done dataloader_idx = kwargs.get("dataloader_idx", 0) diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index 1f52128f06954..26d801bf21ad2 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -186,11 +186,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None: # we are going to train first so the val loop does not need to restart self.val_loop.restarting = False - if not isinstance(data_fetcher, _DataLoaderIterDataFetcher): - batch_idx = self.batch_idx + 1 - batch = next(data_fetcher) - else: - batch_idx, batch = next(data_fetcher) + batch_idx = data_fetcher.fetched if isinstance(data_fetcher, _DataLoaderIterDataFetcher) else self.batch_idx + 1 + batch = next(data_fetcher) self.batch_progress.is_last_batch = data_fetcher.done trainer = self.trainer diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index b7385eb6b5f1b..ee29e7b69c0f7 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
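For context, the `_DataLoaderIterDataFetcher` adjusted here serves the experimental step signature already documented in that class; a sketch of such a step (the module name and `compute_loss` helper below are illustrative, not part of the patch):

```python
from lightning.pytorch import LightningModule


class IterControlModel(LightningModule):
    def training_step(self, dataloader_iter, batch_idx):
        # the step pulls batches itself; batch_idx now reflects how many batches
        # the fetcher has already delivered (its `fetched` count), per the loop change above
        batch = next(dataloader_iter)
        return self.compute_loss(batch)  # hypothetical helper
```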
-from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union from torch.utils.data.dataloader import DataLoader from lightning.fabric.utilities.data import has_len -from lightning.pytorch.trainer.supporters import _shutdown_workers_and_reset_iterator, CombinedLoader +from lightning.pytorch.trainer.supporters import _Sequential, _shutdown_workers_and_reset_iterator, CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -175,15 +175,24 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: def __iter__(self) -> "_DataLoaderIterDataFetcher": super().__iter__() - iterator = self.dataloader_iter - assert iterator is not None self.iterator = iter(_DataFetcherWrapper(self)) return self - def __next__(self) -> Tuple[int, Iterator]: - if not self.done: - return self.fetched, self.iterator - raise StopIteration + def __next__(self) -> Union["_DataFetcherWrapper", Tuple["_DataFetcherWrapper", int, int]]: + if self.done: + raise StopIteration + assert isinstance(self.iterator, _DataFetcherWrapper) + if self._is_sequential: + sequential_mode = self.dataloader._iterator + assert isinstance(sequential_mode, _Sequential) + batch_idx = sequential_mode._idx + dataloader_idx = sequential_mode._iterator_idx + return self.iterator, batch_idx, dataloader_idx + return self.iterator + + @property + def _is_sequential(self) -> bool: + return isinstance(self.dataloader, CombinedLoader) and self.dataloader._mode == "sequential" class _DataFetcherWrapper(Iterator): @@ -191,4 +200,10 @@ def __init__(self, data_fetcher: _DataLoaderIterDataFetcher) -> None: self.data_fetcher = data_fetcher def __next__(self) -> Any: - return super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__() + out = super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__() + if self.data_fetcher._is_sequential: + # avoid breaking change with sequential mode and dataloader_iter. this is okay because + # dataloader_iter + sequential + multiple dataloaders is not supported so the `*_step(..., batch_idx)` value + # and the batch_index we are excluding here will match + return out[0] + return out diff --git a/src/lightning/pytorch/trainer/supporters.py b/src/lightning/pytorch/trainer/supporters.py index 2c3872e239dd5..ffa56538adc60 100644 --- a/src/lightning/pytorch/trainer/supporters.py +++ b/src/lightning/pytorch/trainer/supporters.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
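The `supporters.py` hunk that follows adds optional per-iterator `limits` to sequential mode; exercised directly, mirroring the unit tests added further down, it behaves like this sketch:

```python
from lightning.pytorch.trainer.supporters import _Sequential

# each yielded element is (batch, batch_idx, dataloader_idx)
unlimited = _Sequential([["a", "b", "c"], ["d", "e"]])
assert list(unlimited) == [("a", 0, 0), ("b", 1, 0), ("c", 2, 0), ("d", 0, 1), ("e", 1, 1)]

# cap each iterable at one batch
capped = _Sequential([["a", "b", "c"], ["d", "e"]], limits=[1, 1])
assert list(capped) == [("a", 0, 0), ("d", 0, 1)]
```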
from collections.abc import Iterable -from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar +from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar, Union from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter from typing_extensions import Self, TypedDict @@ -74,27 +74,47 @@ def __next__(self) -> List: return [next(it) for it in self.iterators] -class _Sequential(_ModeIterator[Tuple[int, Any]]): - def __init__(self, iterables: List[Iterable]) -> None: +class _Sequential(_ModeIterator[Tuple[Any, int, int]]): + def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: super().__init__(iterables) self._iterator_idx = 0 # what would be dataloader_idx self._idx = 0 # what would be batch_idx + self.limits = limits - def __next__(self) -> Tuple[int, Any]: + @property + def limits(self) -> Optional[List[Union[int, float]]]: + """Optional limits per iterator.""" + return self._limits + + @limits.setter + def limits(self, limits: Optional[List[Union[int, float]]]) -> None: + if limits is not None and len(limits) != len(self.iterables): + raise ValueError( + f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(self.iterables)})" + ) + self._limits = limits + + def __next__(self) -> Tuple[Any, int, int]: n = len(self.iterators) - if n == 0: + if n == 0 or self._iterator_idx >= n: raise StopIteration + + # if limits are set, go to the correct iterator + if self.limits is not None: + while self.limits[self._iterator_idx] <= self._idx: + self._use_next_iterator() + if self._iterator_idx >= n: + raise StopIteration + try: out = next(self.iterators[self._iterator_idx]) index = self._idx self._idx += 1 - # the return is enumerated by default - return index, out + # batch, batch_idx, dataloader_idx + return out, index, self._iterator_idx except StopIteration: - self._iterator_idx += 1 - self._idx = 0 - if self._iterator_idx >= n: - raise + # try the next iterator + self._use_next_iterator() return self.__next__() def __iter__(self) -> Self: # type: ignore[valid-type] @@ -108,6 +128,10 @@ def reset(self) -> None: self._iterator_idx = 0 self._idx = 0 + def _use_next_iterator(self) -> None: + self._iterator_idx += 1 + self._idx = 0 + class _CombinationMode(TypedDict): fn: Callable[[List[int]], int] @@ -170,28 +194,28 @@ class CombinedLoader(Iterable): >>> combined_loader = CombinedLoader(iterables, 'max_size_cycle') >>> len(combined_loader) 3 - >>> for item in combined_loader: - ... print(item) + >>> for batch in combined_loader: + ... print(batch) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])} >>> combined_loader = CombinedLoader(iterables, 'min_size') >>> len(combined_loader) 2 - >>> for item in combined_loader: - ... print(item) + >>> for batch in combined_loader: + ... print(batch) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} >>> combined_loader = CombinedLoader(iterables, 'sequential') >>> len(combined_loader) 5 - >>> for item in combined_loader: - ... print(*item) - 0 tensor([0, 1, 2, 3]) - 1 tensor([4, 5]) - 0 tensor([0, 1, 2, 3, 4]) - 1 tensor([5, 6, 7, 8, 9]) - 2 tensor([10, 11, 12, 13, 14]) + >>> for batch, batch_idx, dataloader_idx in combined_loader: + ... 
print(f"{batch} {batch_idx=} {dataloader_idx=}") + tensor([0, 1, 2, 3]) batch_idx=0 dataloader_idx=0 + tensor([4, 5]) batch_idx=1 dataloader_idx=0 + tensor([0, 1, 2, 3, 4]) batch_idx=0 dataloader_idx=1 + tensor([5, 6, 7, 8, 9]) batch_idx=1 dataloader_idx=1 + tensor([10, 11, 12, 13, 14]) batch_idx=2 dataloader_idx=1 """ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None: diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py index 01025975a248f..08af8ca7148e8 100644 --- a/tests/tests_pytorch/trainer/test_supporters.py +++ b/tests/tests_pytorch/trainer/test_supporters.py @@ -122,13 +122,14 @@ def test_combined_loader_modes(): combined_loader = CombinedLoader(iterables, "sequential") assert combined_loader._iterator is None assert len(combined_loader) == sum_len - for total_idx, (idx, item) in enumerate(combined_loader): + for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader): assert isinstance(combined_loader._iterator, _Sequential) - assert isinstance(idx, int) + assert isinstance(batch_idx, int) assert isinstance(item, Tensor) assert idx == lengths[-1] - 1 assert total_idx == sum_len - 1 assert total_idx == len(combined_loader) - 1 + assert dataloader_idx == len(iterables) - 1 iterables = list(iterables.values()) @@ -156,13 +157,14 @@ def test_combined_loader_modes(): combined_loader = CombinedLoader(iterables, "sequential") assert combined_loader._iterator is None assert len(combined_loader) == sum_len - for total_idx, (idx, item) in enumerate(combined_loader): + for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader): assert isinstance(combined_loader._iterator, _Sequential) - assert isinstance(idx, int) + assert isinstance(batch_idx, int) assert isinstance(item, Tensor) assert idx == lengths[-1] - 1 assert total_idx == sum_len - 1 assert total_idx == len(combined_loader) - 1 + assert dataloader_idx == len(iterables) - 1 def test_combined_loader_raises(): @@ -205,7 +207,6 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader has_break = False for idx, item in enumerate(combined_loader): assert isinstance(item, Sequence) - assert len(item) == 2 if use_multiple_dataloaders else 1 if not use_multiple_dataloaders and idx == 4: has_break = True break @@ -221,6 +222,27 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader assert idx == expected - 1 +@pytest.mark.parametrize( + ("limits", "expected"), + [ + (None, [("a", 0, 0), ("b", 1, 0), ("c", 2, 0), ("d", 0, 1), ("e", 1, 1)]), + ([1, 0], [("a", 0, 0)]), + ([0, float("inf")], [("d", 0, 1), ("e", 1, 1)]), + ([1, 1], [("a", 0, 0), ("d", 0, 1)]), + ], +) +def test_sequential_mode_limits(limits, expected): + iterable1 = ["a", "b", "c"] + iterable2 = ["d", "e"] + iterator = _Sequential([iterable1, iterable2], limits) + assert list(iterator) == expected + + +def test_sequential_mode_limits_raises(): + with pytest.raises(ValueError, match=r"number of limits \(0\) and number of iterables \(2\)"): + _Sequential([0, 1], []) + + @pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]]) def test_combined_loader_sequence_with_map_and_iterable(lengths): class MyIterableDataset(IterableDataset): From d27881e388b449c75ab18087dd2b6d3dae3baedd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 16 Feb 2023 23:08:13 +0100 Subject: [PATCH 08/14] Fix `set_epoch` not getting called for prediction dataloaders (#16785) --- 
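Why the call matters (illustrative aside, not part of the patch): samplers such as `DistributedSampler` seed their shuffle from the epoch passed to `set_epoch`, so without the forwarding added below every `trainer.predict` run would keep reusing the epoch-0 ordering.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
# explicit num_replicas/rank so no process group is needed for this demo
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
print(list(sampler))  # one shard order for epoch 0
sampler.set_epoch(1)
print(list(sampler))  # a different order, because the epoch seeds the shuffle
```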
src/lightning/pytorch/CHANGELOG.md | 2 ++ .../pytorch/overrides/distributed.py | 4 +++ .../loops/test_prediction_loop.py | 33 +++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 5250071080dad..a4f38d2f66c62 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -321,6 +321,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719)) +- Fixed bug where `set_epoch` was not called for prediction dataloaders ([#16785](https://github.com/Lightning-AI/lightning/pull/16785)) + ## [1.9.1] - 2023-02-10 ### Fixed diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index 3f8f98afb4a13..8830c223c622a 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -135,3 +135,7 @@ def batch_size(self) -> int: @property def sampler(self) -> Union[Sampler, Iterable]: return self._sampler.sampler + + def set_epoch(self, epoch: int) -> None: + if hasattr(self._sampler, "set_epoch"): + self._sampler.set_epoch(epoch) diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index 5d1de82f6536b..1b5e05c502e5a 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -1,3 +1,19 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest import mock +from unittest.mock import call + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel @@ -30,3 +46,20 @@ def predict_step(self, batch, batch_idx): predictions = trainer.predict(model, return_predictions=False) assert predictions is None assert trainer.predict_loop.predictions == [] + + +def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path): + """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction.""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + limit_predict_batches=1, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, + ) + trainer.fit_loop.epoch_progress.current.processed = 2 + + with mock.patch("lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper.set_epoch") as set_epoch_mock: + trainer.predict(model) + assert set_epoch_mock.mock_calls == [call(2)] From 57c1138525ec36d7676e65048048d5773d9ac661 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 17 Feb 2023 00:45:09 +0100 Subject: [PATCH 09/14] Trigger colossalai integration test in CI (#16789) --- .../trainer/connectors/test_accelerator_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 4621648e7c201..7ddfac98d8abe 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -837,6 +837,7 @@ def get_defaults(cls): assert connector_default == trainer_defaults[name] +@RunIf(min_cuda_gpus=1) # trigger this test on our GPU pipeline, because we don't install the package on the CPU suite @pytest.mark.skipif(not package_available("lightning_colossalai"), reason="Requires Colossal AI Strategy") def test_colossalai_external_strategy(monkeypatch): with mock.patch( From 6e359dcc86c6e12ebcbaf1cdddab989d90bd52c7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 17 Feb 2023 01:52:46 +0000 Subject: [PATCH 10/14] [App] Fix idle timeout e2e (#16786) --- src/lightning/app/core/work.py | 5 ++- .../integrations_app/apps/idle_timeout/app.py | 31 +++++++++---------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/lightning/app/core/work.py b/src/lightning/app/core/work.py index abe57ab45f38c..1eb7cacbc1fa6 100644 --- a/src/lightning/app/core/work.py +++ b/src/lightning/app/core/work.py @@ -639,7 +639,10 @@ def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus: return WorkStatus(**status, count=len(timeout_statuses)) def on_exit(self): - """Override this hook to add your logic when the work is exiting.""" + """Override this hook to add your logic when the work is exiting. + + Note: This hook is not guaranteed to be called when running in the cloud. 
+ """ pass def stop(self): diff --git a/tests/integrations_app/apps/idle_timeout/app.py b/tests/integrations_app/apps/idle_timeout/app.py index 31e0d7c124ab6..d33df0a616d58 100644 --- a/tests/integrations_app/apps/idle_timeout/app.py +++ b/tests/integrations_app/apps/idle_timeout/app.py @@ -2,7 +2,7 @@ from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork from lightning.app.storage.path import _artifacts_path, _filesystem -from lightning.app.utilities.enum import WorkStageStatus, WorkStopReasons +from lightning.app.utilities.enum import WorkStageStatus class SourceFileWriterWork(LightningWork): @@ -35,22 +35,21 @@ def run(self): if self.work.counter == 0: self.work.run() - elif ( - self.work.status.stage == WorkStageStatus.STOPPED - and self.work.status.reason == WorkStopReasons.SIGTERM_SIGNAL_HANDLER - and self.make_check - ): - succeeded_status = self.work.statuses[-3] - stopped_status_pending = self.work.statuses[-2] - stopped_status_sigterm = self.work.statuses[-1] - assert succeeded_status.stage == WorkStageStatus.SUCCEEDED - assert stopped_status_pending.stage == WorkStageStatus.STOPPED - assert stopped_status_pending.reason == WorkStopReasons.PENDING - assert stopped_status_sigterm.stage == WorkStageStatus.STOPPED - assert stopped_status_sigterm.reason == WorkStopReasons.SIGTERM_SIGNAL_HANDLER + elif self.work.status.stage == WorkStageStatus.STOPPED and self.make_check: + succeeded_statuses = [status for status in self.work.statuses if status.stage == WorkStageStatus.SUCCEEDED] + # Ensure the work succeeded at some point + assert len(succeeded_statuses) > 0 + succeeded_status = succeeded_statuses[-1] + + stopped_statuses = [status for status in self.work.statuses if status.stage == WorkStageStatus.STOPPED] + + # We want to check that the work started shutting down withing the required timeframe, so we take the first + # status that has `stage == STOPPED`. + stopped_status = stopped_statuses[0] + # Note: Account for the controlplane, k8s, SIGTERM handler delays. 
- assert (stopped_status_pending.timestamp - succeeded_status.timestamp) < 20 - assert (stopped_status_sigterm.timestamp - stopped_status_pending.timestamp) < 120 + assert (stopped_status.timestamp - succeeded_status.timestamp) < 20 + fs = _filesystem() destination_path = _artifacts_path(self.work) / pathlib.Path(*self.work.path.resolve().parts[1:]) assert fs.exists(destination_path) From 91e692c7673d86fe6c8fbd5f378899f3eb1ed23a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 17 Feb 2023 03:06:24 +0100 Subject: [PATCH 11/14] Rename the TPUSpawnStrategy to XLAStrategy (#16781) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source-pytorch/accelerators/tpu_faq.rst | 6 ++--- .../advanced/strategy_registry.rst | 2 +- docs/source-pytorch/api_references.rst | 2 +- docs/source-pytorch/extensions/strategy.rst | 4 ++-- src/lightning/fabric/CHANGELOG.md | 3 ++- src/lightning/fabric/connector.py | 2 +- src/lightning/fabric/strategies/xla.py | 2 -- src/lightning/pytorch/CHANGELOG.md | 6 +++++ src/lightning/pytorch/strategies/__init__.py | 2 +- .../pytorch/strategies/launchers/xla.py | 2 +- .../strategies/{tpu_spawn.py => xla.py} | 13 ++++------ .../connectors/accelerator_connector.py | 10 ++++---- .../tests_fabric/strategies/test_registry.py | 1 - tests/tests_pytorch/accelerators/test_tpu.py | 24 +++++++++---------- tests/tests_pytorch/conftest.py | 2 +- tests/tests_pytorch/models/test_tpu.py | 10 ++++---- .../tests_pytorch/strategies/test_registry.py | 10 ++++---- .../{test_tpu_spawn.py => test_xla.py} | 10 ++++---- .../connectors/test_accelerator_connector.py | 2 +- .../test_estimated_stepping_batches.py | 2 +- 20 files changed, 58 insertions(+), 57 deletions(-) rename src/lightning/pytorch/strategies/{tpu_spawn.py => xla.py} (96%) rename tests/tests_pytorch/strategies/{test_tpu_spawn.py => test_xla.py} (89%) diff --git a/docs/source-pytorch/accelerators/tpu_faq.rst b/docs/source-pytorch/accelerators/tpu_faq.rst index f38f0a865b4cd..de4cd315e4cdb 100644 --- a/docs/source-pytorch/accelerators/tpu_faq.rst +++ b/docs/source-pytorch/accelerators/tpu_faq.rst @@ -61,7 +61,7 @@ How to resolve the replication issue? .format(len(local_devices), len(kind_devices))) RuntimeError: Cannot replicate if number of devices (1) is different from 8 -This error is raised when the XLA device is called outside the spawn process. Internally in `TPUSpawn` Strategy for training on multiple tpu cores, we use XLA's `xmp.spawn`. +This error is raised when the XLA device is called outside the spawn process. Internally in the XLA-Strategy for training on multiple tpu cores, we use XLA's `xmp.spawn`. Don't use ``xm.xla_device()`` while working on Lightning + TPUs! ---- @@ -91,7 +91,7 @@ How to setup the debug mode for Training on TPUs? import pytorch_lightning as pl my_model = MyLightningModule() - trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="tpu_spawn_debug") + trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="xla_debug") trainer.fit(my_model) Example Metrics report: @@ -108,7 +108,7 @@ Example Metrics report: A lot of PyTorch operations aren't lowered to XLA, which could lead to significant slowdown of the training process. These operations are moved to the CPU memory and evaluated, and then the results are transferred back to the XLA device(s). -By using the `tpu_spawn_debug` Strategy, users could create a metrics report to diagnose issues. 
+By using the `xla_debug` Strategy, users could create a metrics report to diagnose issues. The report includes things like (`XLA Reference `_): diff --git a/docs/source-pytorch/advanced/strategy_registry.rst b/docs/source-pytorch/advanced/strategy_registry.rst index 27bab6ea49df4..914db517eb121 100644 --- a/docs/source-pytorch/advanced/strategy_registry.rst +++ b/docs/source-pytorch/advanced/strategy_registry.rst @@ -18,7 +18,7 @@ It also returns the optional description and parameters for initialising the Str trainer = Trainer(strategy="deepspeed_stage_3_offload", accelerator="gpu", devices=3) # Training with the TPU Spawn Strategy with `debug` as True - trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8) + trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8) Additionally, you can pass your custom registered training strategies to the ``strategy`` argument. diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 7e4eb3ca8863f..8187a74ff49fd 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -222,7 +222,7 @@ strategies SingleHPUStrategy SingleTPUStrategy Strategy - TPUSpawnStrategy + XLAStrategy tuner ----- diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst index 034d508474745..6b7474204e6bd 100644 --- a/docs/source-pytorch/extensions/strategy.rst +++ b/docs/source-pytorch/extensions/strategy.rst @@ -90,8 +90,8 @@ The below table lists all relevant strategies available in Lightning with their * - ipu_strategy - :class:`~pytorch_lightning.strategies.IPUStrategy` - Plugin for training on IPU devices. :doc:`Learn more. <../accelerators/ipu>` - * - tpu_spawn - - :class:`~pytorch_lightning.strategies.TPUSpawnStrategy` + * - xla + - :class:`~pytorch_lightning.strategies.XLAStrategy` - Strategy for training on multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` method. :doc:`Learn more. <../accelerators/tpu>` * - single_tpu - :class:`~pytorch_lightning.strategies.SingleTPUStrategy` diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 77d93955929ee..148dd7dff4f25 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -28,7 +28,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `DataParallelStrategy.get_module_state_dict()` and `DDPStrategy.get_module_state_dict()` now correctly extracts the state dict without keys prefixed with 'module' ([#16487](https://github.com/Lightning-AI/lightning/pull/16487)) - - "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490)) * `strategy="fsdp_full_shard_offload"` is now `strategy="fsdp_cpu_offload"` * `lightning.fabric.plugins.precision.native_amp` is now `lightning.fabric.plugins.precision.amp` @@ -36,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Enabled all shorthand strategy names that can be supported in the CLI ([#16485](https://github.com/Lightning-AI/lightning/pull/16485)) +- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781)) + ### Deprecated diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index b0e997dcf83a2..2eb29450bc82a 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -385,7 +385,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: def _choose_strategy(self) -> Union[Strategy, str]: if self._accelerator_flag == "tpu": if self._parallel_devices and len(self._parallel_devices) > 1: - return "tpu_spawn" + return "xla" else: # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device" return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index a04fc0c1055f8..66624239a714c 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -209,8 +209,6 @@ def remove_checkpoint(self, filepath: _PATH) -> None: @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: - # TODO(fabric): Deprecate the name "tpu_spawn" through the connector - strategy_registry.register("tpu_spawn", cls, description=cls.__class__.__name__) strategy_registry.register("xla", cls, description=cls.__class__.__name__) def _set_world_ranks(self) -> None: diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index a4f38d2f66c62..f39ecdb2bdc25 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -89,6 +89,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `dataloader_idx` argument is now optional for the `on_{validation,test,predict}_batch_{start,end}` hooks. 
Remove it or default it to 0 if you don't use multiple dataloaders ([#16753](https://github.com/Lightning-AI/lightning/pull/16753)) + +- Renamed `TPUSpawnStrategy` to `XLAStrategy` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781)) + +- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781)) + + ### Deprecated - diff --git a/src/lightning/pytorch/strategies/__init__.py b/src/lightning/pytorch/strategies/__init__.py index ed48d873b6160..0cc1dc35b4363 100644 --- a/src/lightning/pytorch/strategies/__init__.py +++ b/src/lightning/pytorch/strategies/__init__.py @@ -23,8 +23,8 @@ from lightning.pytorch.strategies.single_hpu import SingleHPUStrategy # noqa: F401 from lightning.pytorch.strategies.single_tpu import SingleTPUStrategy # noqa: F401 from lightning.pytorch.strategies.strategy import Strategy # noqa: F401 -from lightning.pytorch.strategies.tpu_spawn import TPUSpawnStrategy # noqa: F401 from lightning.pytorch.strategies.utils import _call_register_strategies +from lightning.pytorch.strategies.xla import XLAStrategy # noqa: F401 _STRATEGIES_BASE_MODULE = "lightning.pytorch.strategies" StrategyRegistry = _StrategyRegistry() diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py index 2670c860087eb..692b69f9bfb3c 100644 --- a/src/lightning/pytorch/strategies/launchers/xla.py +++ b/src/lightning/pytorch/strategies/launchers/xla.py @@ -47,7 +47,7 @@ class _XLALauncher(_MultiProcessingLauncher): strategy: A reference to the strategy that is used together with this launcher """ - def __init__(self, strategy: "pl.strategies.TPUSpawnStrategy") -> None: + def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None: if not _XLA_AVAILABLE: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(strategy=strategy, start_method="fork") diff --git a/src/lightning/pytorch/strategies/tpu_spawn.py b/src/lightning/pytorch/strategies/xla.py similarity index 96% rename from src/lightning/pytorch/strategies/tpu_spawn.py rename to src/lightning/pytorch/strategies/xla.py index 757bfa3c99fd3..9bdf53fb032e8 100644 --- a/src/lightning/pytorch/strategies/tpu_spawn.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -43,11 +43,11 @@ MpDeviceLoader = None -class TPUSpawnStrategy(DDPSpawnStrategy): +class XLAStrategy(DDPSpawnStrategy): """Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` method.""" - strategy_name = "tpu_spawn" + strategy_name = "xla" def __init__( self, @@ -143,7 +143,7 @@ def is_distributed(self) -> bool: return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1 def process_dataloader(self, dataloader: Iterable) -> "MpDeviceLoader": - TPUSpawnStrategy._validate_dataloader(dataloader) + XLAStrategy._validate_dataloader(dataloader) from torch_xla.distributed.parallel_loader import MpDeviceLoader if isinstance(dataloader, MpDeviceLoader): @@ -192,7 +192,7 @@ def reduce( invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") if invalid_reduce_op or invalid_reduce_op_str: raise ValueError( - "Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" + "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" f" {reduce_op}" ) @@ -293,10 +293,7 @@ def teardown(self) -> None: @classmethod def 
register_strategies(cls, strategy_registry: Dict) -> None: - strategy_registry.register( - "tpu_spawn_debug", cls, description="TPUSpawn Strategy with `debug` as True", debug=True - ) - + strategy_registry.register("xla_debug", cls, description="XLA strategy with `debug` as True", debug=True) strategy_registry.register( cls.strategy_name, cls, diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index d803ec58c25e5..9b09ac2c29542 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -65,7 +65,7 @@ SingleTPUStrategy, Strategy, StrategyRegistry, - TPUSpawnStrategy, + XLAStrategy, ) from lightning.pytorch.strategies.ddp_spawn import _DDP_FORK_ALIASES from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -442,7 +442,7 @@ def _choose_strategy(self) -> Union[Strategy, str]: return SingleHPUStrategy(device=torch.device("hpu")) if self._accelerator_flag == "tpu": if self._parallel_devices and len(self._parallel_devices) > 1: - return TPUSpawnStrategy.strategy_name + return XLAStrategy.strategy_name else: # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device" return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore @@ -617,10 +617,10 @@ def _lazy_init_strategy(self) -> None: # TODO: should be moved to _check_strategy_and_fallback(). # Current test check precision first, so keep this check here to meet error order if isinstance(self.accelerator, TPUAccelerator) and not isinstance( - self.strategy, (SingleTPUStrategy, TPUSpawnStrategy) + self.strategy, (SingleTPUStrategy, XLAStrategy) ): raise ValueError( - "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy`," + "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `XLAStrategy`," f" found {self.strategy.__class__.__name__}." 
) @@ -644,7 +644,7 @@ def is_distributed(self) -> bool: FSDPStrategy, DDPSpawnStrategy, DeepSpeedStrategy, - TPUSpawnStrategy, + XLAStrategy, HPUParallelStrategy, ) is_distributed = isinstance(self.strategy, distributed_strategy) diff --git a/tests/tests_fabric/strategies/test_registry.py b/tests/tests_fabric/strategies/test_registry.py index 6c636fdf9795b..07aee5ea91f0f 100644 --- a/tests/tests_fabric/strategies/test_registry.py +++ b/tests/tests_fabric/strategies/test_registry.py @@ -55,7 +55,6 @@ def test_available_strategies_in_registry(): "ddp_fork", "ddp_notebook", "single_tpu", - "tpu_spawn", "xla", "dp", } diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 22af966eb7554..a2a4389142da8 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -27,7 +27,7 @@ from lightning.pytorch.accelerators.tpu import TPUAccelerator from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.plugins import PrecisionPlugin, TPUPrecisionPlugin, XLACheckpointIO -from lightning.pytorch.strategies import DDPStrategy, TPUSpawnStrategy +from lightning.pytorch.strategies import DDPStrategy, XLAStrategy from lightning.pytorch.utilities import find_shared_parameters from tests_pytorch.helpers.runif import RunIf from tests_pytorch.trainer.optimization.test_manual_optimization import assert_emtpy_grad @@ -94,7 +94,7 @@ def test_accelerator_tpu(accelerator, devices, tpu_available): trainer = Trainer(accelerator=accelerator, devices=devices) assert isinstance(trainer.accelerator, TPUAccelerator) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + assert isinstance(trainer.strategy, XLAStrategy) assert trainer.num_devices == 8 @@ -177,15 +177,15 @@ def test_strategy_choice_tpu_str_ddp_spawn(tpu_available): @RunIf(skip_windows=True) -def test_strategy_choice_tpu_str_tpu_spawn_debug(tpu_available): - trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8) - assert isinstance(trainer.strategy, TPUSpawnStrategy) +def test_strategy_choice_tpu_str_xla_debug(tpu_available): + trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8) + assert isinstance(trainer.strategy, XLAStrategy) @RunIf(tpu=True) def test_strategy_choice_tpu_strategy(): - trainer = Trainer(strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + trainer = Trainer(strategy=XLAStrategy(), accelerator="tpu", devices=8) + assert isinstance(trainer.strategy, XLAStrategy) @RunIf(tpu=True) @@ -237,7 +237,7 @@ def forward(self, x): def test_tpu_invalid_raises(tpu_available): - strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()) + strategy = XLAStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): Trainer(strategy=strategy, devices=8) @@ -248,14 +248,14 @@ def test_tpu_invalid_raises(tpu_available): def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available): accelerator = TPUAccelerator() - strategy = TPUSpawnStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) + strategy = XLAStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): Trainer(strategy=strategy, devices=8) accelerator = TPUAccelerator() 
strategy = DDPStrategy(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin()) with pytest.raises( - ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy" + ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `XLAStrategy" ): Trainer(strategy=strategy, devices=8) @@ -267,11 +267,11 @@ def test_xla_checkpoint_plugin_being_default(tpu_available): @RunIf(tpu=True) -@patch("lightning.pytorch.strategies.tpu_spawn.TPUSpawnStrategy.root_device") +@patch("lightning.pytorch.strategies.xla.XLAStrategy.root_device") def test_xla_mp_device_dataloader_attribute(_, monkeypatch): dataset = RandomDataset(32, 64) dataloader = DataLoader(dataset) - strategy = TPUSpawnStrategy() + strategy = XLAStrategy() isinstance_return = True import torch_xla.distributed.parallel_loader as parallel_loader diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 56b7bb8795ff8..d2723e67fa348 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -174,7 +174,7 @@ def mps_count_4(monkeypatch): @pytest.fixture(scope="function") def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(lightning.pytorch.accelerators.tpu, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.pytorch.strategies.tpu_spawn, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", True) monkeypatch.setattr(lightning.pytorch.strategies.single_tpu, "_XLA_AVAILABLE", True) monkeypatch.setattr(lightning.pytorch.plugins.precision.tpu, "_XLA_AVAILABLE", True) monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", True) diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index 790f100fe3f58..ceebbca6a7194 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -24,7 +24,7 @@ from lightning.pytorch.accelerators import TPUAccelerator from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.strategies import TPUSpawnStrategy +from lightning.pytorch.strategies import XLAStrategy from lightning.pytorch.strategies.launchers.xla import _XLALauncher from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -285,7 +285,7 @@ def wrap_launch_function(fn, strategy, *args, **kwargs): def xla_launch(fn): # TODO: the accelerator should be optional to just launch processes, but this requires lazy initialization accelerator = TPUAccelerator() - strategy = TPUSpawnStrategy(accelerator=accelerator, parallel_devices=list(range(8))) + strategy = XLAStrategy(accelerator=accelerator, parallel_devices=list(range(8))) launcher = _XLALauncher(strategy=strategy) wrapped = partial(wrap_launch_function, fn, strategy) return launcher.launch(wrapped, strategy) @@ -325,7 +325,7 @@ def teardown(self, stage): devices=8, limit_train_batches=0.4, limit_val_batches=0.4, - strategy=TPUSpawnStrategy(debug=True), + strategy=XLAStrategy(debug=True), ) model = DebugModel() @@ -359,6 +359,6 @@ def on_train_start(self): @RunIf(tpu=True) def test_device_type_when_tpu_strategy_passed(tmpdir): - trainer = Trainer(default_root_dir=tmpdir, strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + trainer = 
Trainer(default_root_dir=tmpdir, strategy=XLAStrategy(), accelerator="tpu", devices=8) + assert isinstance(trainer.strategy, XLAStrategy) assert isinstance(trainer.accelerator, TPUAccelerator) diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py index 270fb028fad7f..8882bd441fe1a 100644 --- a/tests/tests_pytorch/strategies/test_registry.py +++ b/tests/tests_pytorch/strategies/test_registry.py @@ -21,7 +21,7 @@ DeepSpeedStrategy, FSDPStrategy, StrategyRegistry, - TPUSpawnStrategy, + XLAStrategy, ) from tests_pytorch.helpers.runif import RunIf @@ -54,15 +54,15 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy): @RunIf(skip_windows=True) -def test_tpu_spawn_debug_strategy_registry(xla_available): - strategy = "tpu_spawn_debug" +def test_xla_debug_strategy_registry(xla_available): + strategy = "xla_debug" assert strategy in StrategyRegistry assert StrategyRegistry[strategy]["init_params"] == {"debug": True} - assert StrategyRegistry[strategy]["strategy"] == TPUSpawnStrategy + assert StrategyRegistry[strategy]["strategy"] == XLAStrategy trainer = Trainer(strategy=strategy) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + assert isinstance(trainer.strategy, XLAStrategy) @RunIf(min_torch="1.12") diff --git a/tests/tests_pytorch/strategies/test_tpu_spawn.py b/tests/tests_pytorch/strategies/test_xla.py similarity index 89% rename from tests/tests_pytorch/strategies/test_tpu_spawn.py rename to tests/tests_pytorch/strategies/test_xla.py index ddc3c2c76b577..d7724464a5515 100644 --- a/tests/tests_pytorch/strategies/test_tpu_spawn.py +++ b/tests/tests_pytorch/strategies/test_xla.py @@ -21,7 +21,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.strategies import TPUSpawnStrategy +from lightning.pytorch.strategies import XLAStrategy from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf @@ -45,7 +45,7 @@ def predict_dataloader(self): def test_error_process_iterable_dataloader(xla_available): - strategy = TPUSpawnStrategy(MagicMock()) + strategy = XLAStrategy(MagicMock()) with pytest.raises(TypeError, match="TPUs do not currently support"): strategy.process_dataloader(_loader_no_len) @@ -60,9 +60,9 @@ def on_train_start(self) -> None: @RunIf(tpu=True, standalone=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_tpu_one_core(): - """Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy.""" + """Tests if device/debug flag is set correctly when training and after teardown for XLAStrategy.""" model = BoringModelTPU() - trainer = Trainer(accelerator="tpu", devices=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True)) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + trainer = Trainer(accelerator="tpu", devices=1, fast_dev_run=True, strategy=XLAStrategy(debug=True)) + assert isinstance(trainer.strategy, XLAStrategy) trainer.fit(model) assert "PT_XLA_DEBUG" not in os.environ diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 7ddfac98d8abe..f5b6c25200940 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -599,7 +599,7 @@ def 
test_unsupported_tpu_choice(tpu_available): with pytest.raises(MisconfigurationException, match=r"accelerator='tpu', precision=64\)` is not implemented"): Trainer(accelerator="tpu", precision=64) - # if user didn't set strategy, AcceleratorConnector will choose the TPUSingleStrategy or TPUSpawnStrategy + # if user didn't set strategy, AcceleratorConnector will choose the TPUSingleStrategy or XLAStrategy with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"), pytest.warns( UserWarning, match=r"accelerator='tpu', precision=16\)` but AMP is not supported" ): diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index f3e5c6daf3b3d..740af109dad4a 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -142,7 +142,7 @@ def test_num_stepping_batches_with_tpu_single(): @RunIf(tpu=True) @mock.patch( - "lightning.pytorch.strategies.tpu_spawn.TPUSpawnStrategy.root_device", + "lightning.pytorch.strategies.xla.XLAStrategy.root_device", new_callable=PropertyMock, return_value=torch.device("xla:0"), ) From 1a6331f88fbe24d0cd29d741cbf3cb258ca5ffa2 Mon Sep 17 00:00:00 2001 From: Noha Alon Date: Fri, 17 Feb 2023 09:26:44 +0200 Subject: [PATCH 12/14] fix warning so the user has a clear next step (#16751) --- src/lightning/app/components/serve/auto_scaler.py | 4 +++- src/lightning/app/components/serve/python_server.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/lightning/app/components/serve/auto_scaler.py b/src/lightning/app/components/serve/auto_scaler.py index ed6fdbf6a4a73..da540c37baa3c 100644 --- a/src/lightning/app/components/serve/auto_scaler.py +++ b/src/lightning/app/components/serve/auto_scaler.py @@ -453,7 +453,9 @@ def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F82 try: from lightning_api_access import APIAccessFrontend except ModuleNotFoundError: - logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI") + logger.warn( + "Some dependencies to run the UI are missing. To resolve, run `pip install lightning-api-access`" + ) return if is_running_in_cloud(): diff --git a/src/lightning/app/components/serve/python_server.py b/src/lightning/app/components/serve/python_server.py index a914135e2cce3..e70335a723ddb 100644 --- a/src/lightning/app/components/serve/python_server.py +++ b/src/lightning/app/components/serve/python_server.py @@ -293,7 +293,9 @@ def configure_layout(self) -> Optional["Frontend"]: try: from lightning_api_access import APIAccessFrontend except ModuleNotFoundError: - logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI") + logger.warn( + "Some dependencies to run the UI are missing. 
To resolve, run `pip install lightning-api-access`" + ) return class_name = self.__class__.__name__ From 3a354acc6172e0332e89f412c8ab3ff8da2277e7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 17 Feb 2023 09:33:17 +0000 Subject: [PATCH 13/14] [App] Reserve APP_SERVER_PORT in cloud port allocation (#16782) Co-authored-by: thomas chaton --- src/lightning/app/utilities/network.py | 2 +- tests/tests_app/utilities/test_network.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py index fb3576b48d22d..db4e5d9f9afdf 100644 --- a/src/lightning/app/utilities/network.py +++ b/src/lightning/app/utilities/network.py @@ -66,7 +66,7 @@ def find_free_network_port() -> int: def _find_free_network_port_cloudspace(): """Finds a free port in the exposed range when running in a cloudspace.""" for port in range( - constants.APP_SERVER_PORT, + constants.APP_SERVER_PORT + 1, # constants.APP_SERVER_PORT is reserved for the app server constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT, ): if port in _reserved_ports: diff --git a/tests/tests_app/utilities/test_network.py b/tests/tests_app/utilities/test_network.py index 1795d5d524966..f8cc25304f0ae 100644 --- a/tests/tests_app/utilities/test_network.py +++ b/tests/tests_app/utilities/test_network.py @@ -2,6 +2,7 @@ import pytest +from lightning.app.core import constants from lightning.app.utilities.network import find_free_network_port, LightningClient @@ -40,6 +41,9 @@ def test_find_free_network_port_cloudspace(_, patch_constants): # Check that all ports are unique assert len(ports) == num_ports + # Shouldn't use the APP_SERVER_PORT + assert constants.APP_SERVER_PORT not in ports + def test_lightning_client_retry_enabled(): From ac5fa03385d8b5d5872f3b610fbca38898ec6fd2 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 17 Feb 2023 11:41:18 +0100 Subject: [PATCH 14/14] Introduce new precision layout in fabric (#16767) --- .../source-pytorch/fabric/api/fabric_args.rst | 14 ++-- .../fabric/fundamentals/launch.rst | 10 ++- .../fabric/fundamentals/precision.rst | 32 +++++---- .../fabric/guide/multi_node/cloud.rst | 2 +- src/lightning/fabric/CHANGELOG.md | 2 + src/lightning/fabric/cli.py | 11 +-- src/lightning/fabric/connector.py | 68 ++++++++++++------- src/lightning/fabric/fabric.py | 6 +- src/lightning/fabric/plugins/precision/amp.py | 19 +++--- .../fabric/plugins/precision/deepspeed.py | 14 ++-- .../fabric/plugins/precision/double.py | 2 +- .../fabric/plugins/precision/fsdp.py | 8 +-- .../fabric/plugins/precision/precision.py | 8 ++- .../fabric/plugins/precision/tpu_bf16.py | 2 +- src/lightning/fabric/strategies/deepspeed.py | 12 ++-- src/lightning/fabric/strategies/fsdp.py | 5 +- .../pytorch/plugins/precision/amp.py | 2 +- .../pytorch/plugins/precision/deepspeed.py | 2 +- .../pytorch/plugins/precision/double.py | 2 +- .../pytorch/plugins/precision/hpu.py | 2 +- .../pytorch/plugins/precision/ipu.py | 2 +- .../pytorch/plugins/precision/tpu_bf16.py | 2 +- src/lightning/pytorch/trainer/trainer.py | 2 +- .../plugins/precision/test_amp.py | 18 ++--- .../plugins/precision/test_amp_integration.py | 8 +-- .../plugins/precision/test_deepspeed.py | 6 +- .../precision/test_deepspeed_integration.py | 2 +- .../precision/test_double_integration.py | 2 +- .../plugins/precision/test_fsdp.py | 4 +- .../strategies/test_deepspeed_integration.py | 14 ++-- .../strategies/test_fsdp_integration.py | 2 +- 
tests/tests_fabric/test_cli.py | 4 +- tests/tests_fabric/test_connector.py | 60 ++++++++++++---- 33 files changed, 214 insertions(+), 135 deletions(-) diff --git a/docs/source-pytorch/fabric/api/fabric_args.rst b/docs/source-pytorch/fabric/api/fabric_args.rst index 244129549a110..3ee1fe9e9529f 100644 --- a/docs/source-pytorch/fabric/api/fabric_args.rst +++ b/docs/source-pytorch/fabric/api/fabric_args.rst @@ -112,23 +112,27 @@ Learn more about :ref:`distributed multi-node training on clusters `_). +Fabric supports double precision (64 bit), full precision (32 bit), or half-precision (16 bit) floating point operation (including `bfloat16 `_). Half precision, or mixed precision, combines 32 and 16-bit floating points to reduce the memory footprint during model training. +Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while settings that only work in the specified precision have a ``"-true"`` suffix. This can result in improved performance, achieving significant speedups on modern GPUs. .. code-block:: python # Default used by the Fabric - fabric = Fabric(precision=32, devices=1) + fabric = Fabric(precision="32-true", devices=1) + + # the same as: + fabric = Fabric(precision="32", devices=1) # 16-bit (mixed) precision - fabric = Fabric(precision=16, devices=1) + fabric = Fabric(precision="16-mixed", devices=1) # 16-bit bfloat precision - fabric = Fabric(precision="bf16", devices=1) + fabric = Fabric(precision="bf16-mixed", devices=1) # 64-bit (double) precision - fabric = Fabric(precision=64, devices=1) + fabric = Fabric(precision="64-true", devices=1) See also: :doc:`../fundamentals/precision` diff --git a/docs/source-pytorch/fabric/fundamentals/launch.rst b/docs/source-pytorch/fabric/fundamentals/launch.rst index 9a49d9b050f4f..a8311e6134c14 100644 --- a/docs/source-pytorch/fabric/fundamentals/launch.rst +++ b/docs/source-pytorch/fabric/fundamentals/launch.rst @@ -68,9 +68,13 @@ This is essentially the same as running ``python path/to/your/script.py``, but i --main-port, --main_port INTEGER The main port to connect to the main machine. - --precision [64|32|16|bf16] Double precision (``64``), full precision - (``32``), half precision (``16``) or - bfloat16 precision (``'bf16'``) + --precision [16-mixed|bf16-mixed|32-true|64-true|64|32|16|bf16] + Double precision (``64-true`` or ``64``), + full precision (``32-true`` or ``64``), half + precision (``16-mixed`` or ``16``) or + bfloat16 precision (``bf16-mixed`` or + ``bf16``) + --help Show this message and exit. 
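The legacy flags listed in the CLI help above remain accepted; the connector normalizes them to the new strings (see `_convert_precision_to_unified_args` and `_PRECISION_INPUT_STR_ALIAS_CONVERSION` in the hunks further below, where the 16-bit aliases additionally emit a warning). A minimal sketch of that normalization, using illustrative names only:

    # sketch only: mirrors the alias table added in plugins/precision/precision.py
    _ALIASES = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"}

    def _normalize(precision):
        precision = str(precision)               # accept ints such as 16 or 32
        return _ALIASES.get(precision, precision)

    assert _normalize(16) == "16-mixed"
    assert _normalize("bf16") == "bf16-mixed"
    assert _normalize("32-true") == "32-true"    # new-style values pass through unchanged
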
diff --git a/docs/source-pytorch/fabric/fundamentals/precision.rst b/docs/source-pytorch/fabric/fundamentals/precision.rst index d3a5b2ab726df..5d24b41ba4e54 100644 --- a/docs/source-pytorch/fabric/fundamentals/precision.rst +++ b/docs/source-pytorch/fabric/fundamentals/precision.rst @@ -24,18 +24,27 @@ This is how you select the precision in Fabric: from lightning.fabric import Fabric # This is the default + fabric = Fabric(precision="32-true") + + # Also FP32 fabric = Fabric(precision=32) - # FP16 mixed precision - fabric = Fabric(precision=16) + # FP32 as well + fabric = Fabric(precision="32") - # Precision values can also be set as a string - fabric = Fabric(precision="16") + # FP16 mixed precision + fabric = Fabric(precision="16-mixed) # BFloat16 precision (Volta GPUs and later) - fabric = Fabric(precision="bf16") + fabric = Fabric(precision="bf16-mixed") # Double precision + fabric = Fabric(precision="64-true") + + # Or + fabric = Fabric(precision="64") + + # Or fabric = Fabric(precision=64) @@ -43,7 +52,7 @@ The same values can also be set through the :doc:`command line interface `__ with the dtype set to ``bfloat16``, with no gradient scaling. @@ -117,7 +123,7 @@ Fabric automatically casts the data type and operations in the ``forward`` of yo .. code-block:: python - fabric = Fabric(precision="bf16") + fabric = Fabric(precision="bf16-mixed") model = ... optimizer = ... diff --git a/docs/source-pytorch/fabric/guide/multi_node/cloud.rst b/docs/source-pytorch/fabric/guide/multi_node/cloud.rst index 117e64c6e3592..f5a7d9b352b0e 100644 --- a/docs/source-pytorch/fabric/guide/multi_node/cloud.rst +++ b/docs/source-pytorch/fabric/guide/multi_node/cloud.rst @@ -50,7 +50,7 @@ Launch multi-node training in the cloud def run(self): # Set up Fabric # The `devices` and `num_nodes` gets set by Lightning automatically - fabric = L.Fabric(strategy="ddp", precision=16) + fabric = L.Fabric(strategy="ddp", precision="16-mixed") # Your training code model = ... diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 148dd7dff4f25..f53eeea7081fe 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781)) +- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16767](https://github.com/Lightning-AI/lightning/pull/16767)) + ### Deprecated - diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 6ade6d5ce1039..4671a75da7577 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -18,8 +18,10 @@ from typing import Any, List, Optional from lightning_utilities.core.imports import RequirementCache +from typing_extensions import get_args from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator +from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS from lightning.fabric.strategies import STRATEGY_REGISTRY from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -28,7 +30,6 @@ _CLICK_AVAILABLE = RequirementCache("click") _SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") -_SUPPORTED_PRECISION = ("64", "32", "16", "bf16") def _get_supported_strategies() -> List[str]: @@ -106,11 +107,11 @@ def _get_supported_strategies() -> List[str]: ) @click.option( "--precision", - type=click.Choice(_SUPPORTED_PRECISION), - default="32", + type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)), + default="32-true", help=( - "Double precision (``64``), full precision (``32``), half precision (``16``) or bfloat16 precision" - " (``'bf16'``)" + "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``64``), " + "half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)" ), ) @click.argument("script_args", nargs=-1, type=click.UNPROCESSED) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 2eb29450bc82a..c8276f6ef8459 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -42,7 +42,13 @@ ) from lightning.fabric.plugins.precision.double import DoublePrecision from lightning.fabric.plugins.precision.fsdp import FSDPPrecision -from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR +from lightning.fabric.plugins.precision.precision import ( + _PRECISION_INPUT, + _PRECISION_INPUT_INT, + _PRECISION_INPUT_STR, + _PRECISION_INPUT_STR_ALIAS, + _PRECISION_INPUT_STR_ALIAS_CONVERSION, +) from lightning.fabric.strategies import ( DeepSpeedStrategy, ParallelStrategy, @@ -98,7 +104,7 @@ def __init__( strategy: Optional[Union[str, Strategy]] = None, devices: Optional[Union[List[int], str, int]] = None, num_nodes: int = 1, - precision: _PRECISION_INPUT = 32, + precision: _PRECISION_INPUT = "32-true", plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, ) -> None: @@ -107,7 +113,7 @@ def __init__( strategy = self._argument_from_env("strategy", strategy, default=None) devices = self._argument_from_env("devices", devices, default=None) num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1) - precision = self._argument_from_env("precision", precision, default=32) + precision = self._argument_from_env("precision", precision, default="32-true") # 1. Parsing flags # Get registered strategies, built-in accelerators and precision plugins @@ -119,7 +125,7 @@ def __init__( # For devices: Assign gpus, etc. 
to the accelerator flag and devices flag self._strategy_flag: Optional[Union[Strategy, str]] = None self._accelerator_flag: Optional[Union[Accelerator, str]] = None - self._precision_input: _PRECISION_INPUT_STR = "32" + self._precision_input: _PRECISION_INPUT_STR = "32-true" self._precision_instance: Optional[Precision] = None self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None self._parallel_devices: List[Union[int, torch.device, str]] = [] @@ -220,10 +226,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = accelerator - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) - if precision not in supported_precision: - raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}") - self._precision_input = cast(_PRECISION_INPUT_STR, str(precision)) + self._precision_input = _convert_precision_to_unified_args(precision) if plugins: plugins_flags_types: Dict[str, int] = Counter() @@ -453,34 +456,34 @@ def _check_and_init_precision(self) -> Precision: return self._precision_instance if isinstance(self.accelerator, TPUAccelerator): - if self._precision_input == "32": + if self._precision_input == "32-true": return TPUPrecision() - elif self._precision_input in ("16", "bf16"): - if self._precision_input == "16": + elif self._precision_input in ("16-mixed", "bf16-mixed"): + if self._precision_input == "16-mixed": rank_zero_warn( - "You passed `Fabric(accelerator='tpu', precision=16)` but AMP" - " is not supported with TPUs. Using `precision='bf16'` instead." + "You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16" + " is not supported with TPUs. Using `precision='bf16-mixed'` instead." ) return TPUBf16Precision() if isinstance(self.strategy, DeepSpeedStrategy): return DeepSpeedPrecision(self._precision_input) # type: ignore - if self._precision_input == "32": + if self._precision_input == "32-true": return Precision() - if self._precision_input == "64": + if self._precision_input == "64-true": return DoublePrecision() - if self._precision_input == "16" and self._accelerator_flag == "cpu": + if self._precision_input == "16-mixed" and self._accelerator_flag == "cpu": rank_zero_warn( - "You passed `Fabric(accelerator='cpu', precision=16)` but AMP is not supported on CPU." - " Using `precision='bf16'` instead." + "You passed `Fabric(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on " + "CPU. Using `precision='bf16-mixed'` instead." ) - self._precision_input = "bf16" + self._precision_input = "bf16-mixed" - if self._precision_input in ("16", "bf16"): + if self._precision_input in ("16-mixed", "bf16-mixed"): rank_zero_info( "Using 16-bit Automatic Mixed Precision (AMP)" - if self._precision_input == "16" + if self._precision_input == "16-mixed" else "Using bfloat16 Automatic Mixed Precision (AMP)" ) device = "cpu" if self._accelerator_flag == "cpu" else "cuda" @@ -494,9 +497,9 @@ def _check_and_init_precision(self) -> Precision: def _validate_precision_choice(self) -> None: """Validate the combination of choices for precision, and accelerator.""" if isinstance(self.accelerator, TPUAccelerator): - if self._precision_input == "64": + if self._precision_input == "64-true": raise NotImplementedError( - "`Fabric(accelerator='tpu', precision=64)` is not implemented." + "`Fabric(accelerator='tpu', precision='64-true')` is not implemented." 
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" " requesting this feature." ) @@ -561,3 +564,22 @@ def _argument_from_env(name: str, current: Any, default: Any) -> Any: if env_value is None: return current return env_value + + +def _convert_precision_to_unified_args(precision: _PRECISION_INPUT) -> _PRECISION_INPUT_STR: + supported_precision = ( + get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_ALIAS) + ) + if precision not in supported_precision: + raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}") + + precision = str(precision) # convert int flags to str here to enable the legacy-conversion below + + if precision in get_args(_PRECISION_INPUT_STR_ALIAS): + if str(precision)[:2] not in ("32", "64"): + rank_zero_warn( + f"{precision} is supported for historical reasons but its usage is discouraged. " + f"Please set your precision to {_PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]} instead!" + ) + precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision] + return cast(_PRECISION_INPUT_STR, precision) diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 7e41844243673..47a3097244258 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -67,8 +67,8 @@ class Fabric: devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``. The value applies per node. num_nodes: Number of GPU nodes for distributed training. - precision: Double precision (``64``), full precision (``32``), half precision (``16``), - or bfloat16 precision (``"bf16"``). + precision: Double precision (``"64-true"``), full precision (``"32"``), half precision AMP (``"16-mixed"``), + or bfloat16 precision AMP (``"bf16-mixed"``). plugins: One or several custom plugins callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user. @@ -82,7 +82,7 @@ def __init__( strategy: Optional[Union[str, Strategy]] = None, devices: Optional[Union[List[int], str, int]] = None, num_nodes: int = 1, - precision: _PRECISION_INPUT = 32, + precision: _PRECISION_INPUT = "32-true", plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, callbacks: Optional[Union[List[Any], Any]] = None, loggers: Optional[Union[Logger, List[Logger]]] = None, diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index 5fa6752ae4e0c..e7ff4858220d2 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -29,21 +29,24 @@ class MixedPrecision(Precision): """Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``. Args: - precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``). + precision: Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``). device: The device for ``torch.autocast``. scaler: An optional :class:`torch.cuda.amp.GradScaler` to use. 
""" def __init__( - self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None + self, + precision: Literal["16-mixed", "bf16-mixed"], + device: str, + scaler: Optional[torch.cuda.amp.GradScaler] = None, ) -> None: - self.precision = cast(Literal["16", "bf16"], str(precision)) - if scaler is None and self.precision == "16": + self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision)) + if scaler is None and self.precision == "16-mixed": with _patch_cuda_is_available(): # if possible, we defer CUDA initialization to support strategies that will attempt forks scaler = torch.cuda.amp.GradScaler() - if scaler is not None and self.precision == "bf16": - raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.") + if scaler is not None and self.precision == "bf16-mixed": + raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.") self.device = device self.scaler = scaler @@ -53,7 +56,7 @@ def forward_context(self) -> Generator[None, None, None]: yield def convert_input(self, data: Tensor) -> Tensor: - precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16} + precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16} dst_type = precision_to_type[self.precision] return _convert_fp_tensor(data, dst_type) @@ -89,4 +92,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def _autocast_context_manager(self) -> torch.autocast: # the dtype could be automatically inferred but we need to manually set it due to a bug upstream # https://github.com/pytorch/pytorch/issues/67233 - return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half) + return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half) diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index c06a43d8c03c0..44195f823e04f 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, cast, Literal, TYPE_CHECKING, Union +from typing import Any, Literal, TYPE_CHECKING import torch from torch import Tensor @@ -27,16 +27,14 @@ if _DEEPSPEED_AVAILABLE: # type: ignore[has-type] import deepspeed -_PRECISION_INPUT_INT = Literal[32, 16] -_PRECISION_INPUT_STR = Literal["32", "16", "bf16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] +_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"] class DeepSpeedPrecision(Precision): """Precision plugin for DeepSpeed integration. Args: - precision: Full precision (32), half precision (16) or bfloat16 precision (bf16). + precision: Full precision (32-true), half precision (16-mixed) or bfloat16 precision (bf16-mixed). Raises: ValueError: @@ -44,16 +42,16 @@ class DeepSpeedPrecision(Precision): """ def __init__(self, precision: _PRECISION_INPUT) -> None: - supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + supported_precision = get_args(_PRECISION_INPUT) if precision not in supported_precision: raise ValueError( f"`precision={precision!r})` is not supported in DeepSpeed." f" `precision` must be one of: {supported_precision}." 
) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) + self.precision = precision def convert_input(self, data: Tensor) -> Tensor: - precision_to_type = {"bf16": torch.bfloat16, "16": torch.float16, "32": torch.float32} + precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16, "32-true": torch.float32} dst_type = precision_to_type[self.precision] return _convert_fp_tensor(data, dst_type) diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py index 652fe8ede5076..687dec35ed568 100644 --- a/src/lightning/fabric/plugins/precision/double.py +++ b/src/lightning/fabric/plugins/precision/double.py @@ -25,7 +25,7 @@ class DoublePrecision(Precision): """Plugin for training with double (``torch.float64``) precision.""" - precision: Literal["64"] = "64" + precision: Literal["64-true"] = "64-true" def convert_module(self, module: Module) -> Module: return module.double() diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index dc471049df2d0..10de06ab2707a 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -27,7 +27,7 @@ class FSDPPrecision(MixedPrecision): """AMP for Fully Sharded Data Parallel training.""" def __init__( - self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None + self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional["ShardedGradScaler"] = None ) -> None: if not _TORCH_GREATER_EQUAL_1_12: raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.") @@ -37,16 +37,16 @@ def __init__( super().__init__( precision=precision, device=device, - scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None), + scaler=(ShardedGradScaler() if scaler is None and precision == "16-mixed" else None), ) @property def mixed_precision_config(self) -> "TorchMixedPrecision": from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision - if self.precision == "16": + if self.precision == "16-mixed": dtype = torch.float16 - elif self.precision == "bf16": + elif self.precision == "bf16-mixed": dtype = torch.bfloat16 else: raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.") diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index b609ee78a21e9..e1add043662fe 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -23,8 +23,10 @@ from lightning.fabric.utilities.types import _PARAMETERS, Optimizable _PRECISION_INPUT_INT = Literal[64, 32, 16] -_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR] +_PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"} +_PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"] +_PRECISION_INPUT_STR = Literal["16-mixed", "bf16-mixed", "32-true", "64-true"] +_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS] class Precision: @@ -33,7 +35,7 @@ class Precision: The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. 
""" - precision: _PRECISION_INPUT_STR = "32" + precision: _PRECISION_INPUT_STR = "32-true" def convert_module(self, module: Module) -> Module: """Convert the module parameters to the precision type this plugin handles. diff --git a/src/lightning/fabric/plugins/precision/tpu_bf16.py b/src/lightning/fabric/plugins/precision/tpu_bf16.py index 79654a9c041a3..d0b9dc5b686e6 100644 --- a/src/lightning/fabric/plugins/precision/tpu_bf16.py +++ b/src/lightning/fabric/plugins/precision/tpu_bf16.py @@ -24,7 +24,7 @@ class TPUBf16Precision(TPUPrecision): """Plugin that enables bfloats on TPUs.""" - precision: Literal["bf16"] = "bf16" + precision: Literal["bf16-mixed"] = "bf16-mixed" def __init__(self) -> None: super().__init__() diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index dab9b7ec5886e..290b2d05c6334 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -105,8 +105,8 @@ def __init__( Arguments: - zero_optimization: Enable ZeRO optimization. This is compatible with either ``precision=16`` or - ``precision="bf16"``. + zero_optimization: Enable ZeRO optimization. This is compatible with either ``precision="16-mixed"`` or + ``precision="bf16-mixed"``. stage: Different stages of the ZeRO Optimizer. 0 is disabled, 1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning, @@ -350,9 +350,9 @@ def module_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: assert self._config_initialized - if self.precision.precision == "16": + if self.precision.precision == "16-mixed": dtype = torch.float16 - elif self.precision.precision == "bf16": + elif self.precision.precision == "bf16-mixed": dtype = torch.bfloat16 else: dtype = torch.float32 @@ -604,7 +604,7 @@ def _format_config(self) -> None: def _format_precision_config(self) -> None: assert isinstance(self.config, dict) - if self.precision.precision == "16": + if self.precision.precision == "16-mixed": if "fp16" not in self.config: # FP16 is a DeepSpeed standalone AMP implementation rank_zero_info("Enabling DeepSpeed FP16.") @@ -616,7 +616,7 @@ def _format_precision_config(self) -> None: "hysteresis": self.hysteresis, "min_loss_scale": self.min_loss_scale, } - elif "bf16" not in self.config and self.precision.precision == "bf16": + elif "bf16" not in self.config and self.precision.precision == "bf16-mixed": rank_zero_info("Enabling DeepSpeed BF16.") self.config["bf16"] = {"enabled": True} diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index ebb27dfb15b49..b465bc9a94126 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -76,8 +76,9 @@ class FSDPStrategy(ParallelStrategy, _Sharded): backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows users to enable two different backward prefetching algorithms to help backward communication and computation overlapping. The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. - mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16 - if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later. + mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision="16-mixed"`` or + BF16 if ``precision="bf16-mixed"`` unless a config is passed in. 
+ This is only available in PyTorch 1.12 and later. activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation checkpointing. This is typically your transformer block (including attention + feed-forward). Enabling this can free up a significant amount of memory at the cost of speed since activations in diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index ce984070ae7a5..3d6b894097649 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -36,7 +36,7 @@ class MixedPrecisionPlugin(PrecisionPlugin): def __init__( self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None ) -> None: - self.precision = cast(Literal["16", "bf16"], str(precision)) + self.precision = cast(Literal["16", "bf16"], str(precision)) # type: ignore if scaler is None and self.precision == "16": with _patch_cuda_is_available(): # if possible, we defer CUDA initialization to support strategies that will attempt forks diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 20f5748aa444e..8f0845303c8ba 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -53,7 +53,7 @@ def __init__(self, precision: Literal["32", 32, "16", 16, "bf16"]) -> None: f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." ) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) + self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore def backward( # type: ignore[override] self, diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index e008097046637..78785a4c58ca5 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -72,7 +72,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class DoublePrecisionPlugin(PrecisionPlugin): """Plugin for training with double (``torch.float64``) precision.""" - precision: Literal["64"] = "64" + precision: Literal["64"] = "64" # type: ignore def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] diff --git a/src/lightning/pytorch/plugins/precision/hpu.py b/src/lightning/pytorch/plugins/precision/hpu.py index 8d805deae1da8..e668285c445c5 100644 --- a/src/lightning/pytorch/plugins/precision/hpu.py +++ b/src/lightning/pytorch/plugins/precision/hpu.py @@ -54,7 +54,7 @@ def __init__( f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." 
) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) + self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore if self.precision in ("16", "bf16"): hmp.convert( opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose diff --git a/src/lightning/pytorch/plugins/precision/ipu.py b/src/lightning/pytorch/plugins/precision/ipu.py index f82bc07ac2119..104cec0dcfe99 100644 --- a/src/lightning/pytorch/plugins/precision/ipu.py +++ b/src/lightning/pytorch/plugins/precision/ipu.py @@ -47,7 +47,7 @@ def __init__(self, precision: Literal["32", 32, "16", 16]) -> None: f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported." f" `precision` must be one of: {supported_precision}." ) - self.precision = cast(_PRECISION_INPUT_STR, str(precision)) + self.precision = cast(_PRECISION_INPUT_STR, str(precision)) # type: ignore def backward( # type: ignore[override] self, diff --git a/src/lightning/pytorch/plugins/precision/tpu_bf16.py b/src/lightning/pytorch/plugins/precision/tpu_bf16.py index 814af173d0464..aff41d9c92357 100644 --- a/src/lightning/pytorch/plugins/precision/tpu_bf16.py +++ b/src/lightning/pytorch/plugins/precision/tpu_bf16.py @@ -23,7 +23,7 @@ class TPUBf16PrecisionPlugin(TPUPrecisionPlugin): """Plugin that enables bfloats on TPUs.""" - precision: Literal["bf16"] = "bf16" + precision: Literal["bf16"] = "bf16" # type: ignore def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 35beb7c57d662..d94d660e4f1e8 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -1226,7 +1226,7 @@ def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: @property def precision(self) -> _PRECISION_INPUT_STR: - return self.strategy.precision_plugin.precision + return self.strategy.precision_plugin.precision # type: ignore @property def scaler(self) -> Optional[Any]: diff --git a/tests/tests_fabric/plugins/precision/test_amp.py b/tests/tests_fabric/plugins/precision/test_amp.py index fca1bd9140968..ae96107a552f4 100644 --- a/tests/tests_fabric/plugins/precision/test_amp.py +++ b/tests/tests_fabric/plugins/precision/test_amp.py @@ -20,21 +20,21 @@ def test_amp_precision_default_scaler(): - precision = MixedPrecision(precision=16, device=Mock()) + precision = MixedPrecision(precision="16-mixed", device=Mock()) assert isinstance(precision.scaler, torch.cuda.amp.GradScaler) def test_amp_precision_scaler_with_bf16(): - with pytest.raises(ValueError, match="`precision='bf16'` does not use a scaler"): - MixedPrecision(precision="bf16", device=Mock(), scaler=Mock()) + with pytest.raises(ValueError, match="`precision='bf16-mixed'` does not use a scaler"): + MixedPrecision(precision="bf16-mixed", device=Mock(), scaler=Mock()) - precision = MixedPrecision(precision="bf16", device=Mock()) + precision = MixedPrecision(precision="bf16-mixed", device=Mock()) assert precision.scaler is None def test_amp_precision_forward_context(): """Test to ensure that the context manager correctly is set to bfloat16 on CPU and CUDA.""" - precision = MixedPrecision(precision=16, device="cuda") + precision = MixedPrecision(precision="16-mixed", device="cuda") assert precision.device == "cuda" assert isinstance(precision.scaler, torch.cuda.amp.GradScaler) assert torch.get_default_dtype() == torch.float32 @@ -42,7 +42,7 @@ def test_amp_precision_forward_context(): # check 
with str due to a bug upstream: https://github.com/pytorch/pytorch/issues/65786 assert str(torch.get_autocast_gpu_dtype()) in ("torch.float16", "torch.half") - precision = MixedPrecision(precision="bf16", device="cpu") + precision = MixedPrecision(precision="bf16-mixed", device="cpu") assert precision.device == "cpu" assert precision.scaler is None with precision.forward_context(): @@ -56,7 +56,7 @@ def test_amp_precision_forward_context(): def test_amp_precision_backward(): - precision = MixedPrecision(precision="mixed", device="cuda") + precision = MixedPrecision(precision="16-mixed", device="cuda") precision.scaler = Mock() precision.scaler.scale = Mock(side_effect=(lambda x: x)) tensor = Mock() @@ -67,7 +67,7 @@ def test_amp_precision_backward(): def test_amp_precision_optimizer_step_with_scaler(): - precision = MixedPrecision(precision="mixed", device="cuda") + precision = MixedPrecision(precision="16-mixed", device="cuda") precision.scaler = Mock() optimizer = Mock() @@ -77,7 +77,7 @@ def test_amp_precision_optimizer_step_with_scaler(): def test_amp_precision_optimizer_step_without_scaler(): - precision = MixedPrecision(precision="bf16", device="cuda") + precision = MixedPrecision(precision="bf16-mixed", device="cuda") assert precision.scaler is None optimizer = Mock() diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py index 29c809d7c7b7b..133060b45a687 100644 --- a/tests/tests_fabric/plugins/precision/test_amp_integration.py +++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py @@ -61,10 +61,10 @@ def after_backward(self, model): @pytest.mark.parametrize( "accelerator, precision, expected_dtype", [ - ("cpu", 16, torch.bfloat16), - ("cpu", "bf16", torch.bfloat16), - pytest.param("cuda", 16, torch.float16, marks=RunIf(min_cuda_gpus=1)), - pytest.param("cuda", "bf16", torch.bfloat16, marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), + ("cpu", "16-mixed", torch.bfloat16), + ("cpu", "bf16-mixed", torch.bfloat16), + pytest.param("cuda", "16-mixed", torch.float16, marks=RunIf(min_cuda_gpus=1)), + pytest.param("cuda", "bf16-mixed", torch.bfloat16, marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)), ], ) def test_amp(accelerator, precision, expected_dtype): diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed.py b/tests/tests_fabric/plugins/precision/test_deepspeed.py index 5eed7c0a4d933..6f8316e0d8f9e 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed.py @@ -23,11 +23,11 @@ def test_invalid_precision_with_deepspeed_precision(): with pytest.raises(ValueError, match="is not supported in DeepSpeed. 
`precision` must be one of"): - DeepSpeedPrecision(precision=64) + DeepSpeedPrecision(precision="64-true") def test_deepspeed_precision_backward(): - precision = DeepSpeedPrecision(precision=32) + precision = DeepSpeedPrecision(precision="32-true") tensor = Mock() model = Mock() precision.backward(tensor, model, "positional-arg", keyword="arg") @@ -45,7 +45,7 @@ def test_deepspeed_engine_is_steppable(engine): def test_deepspeed_precision_optimizer_step(): - precision = DeepSpeedPrecision(precision=32) + precision = DeepSpeedPrecision(precision="32-true") optimizer = model = Mock() precision.optimizer_step(optimizer, lr_kwargs=dict()) model.step.assert_called_once_with(lr_kwargs=dict()) diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py index 5dfc670603db0..544649c163166 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py @@ -22,7 +22,7 @@ @RunIf(deepspeed=True) -@pytest.mark.parametrize("precision", ["bf16", 16, 32]) +@pytest.mark.parametrize("precision", ["bf16-mixed", "16-mixed", "32-true"]) @mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False) def test_deepspeed_precision_choice(_, precision): """Test to ensure precision plugin is correctly chosen. diff --git a/tests/tests_fabric/plugins/precision/test_double_integration.py b/tests/tests_fabric/plugins/precision/test_double_integration.py index 7af8ac6e28416..012d1f39623b0 100644 --- a/tests/tests_fabric/plugins/precision/test_double_integration.py +++ b/tests/tests_fabric/plugins/precision/test_double_integration.py @@ -50,5 +50,5 @@ def after_backward(self, model): def test_double_precision(): - fabric = DoublePrecisionBoringFabric(precision=64) + fabric = DoublePrecisionBoringFabric(precision="64-true") fabric.run() diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py index cee1afe35941e..fb121dc191b7b 100644 --- a/tests/tests_fabric/plugins/precision/test_fsdp.py +++ b/tests/tests_fabric/plugins/precision/test_fsdp.py @@ -23,11 +23,11 @@ @mock.patch("lightning.fabric.plugins.precision.fsdp._TORCH_GREATER_EQUAL_1_12", False) def test_fsdp_precision_support(*_): with pytest.raises(NotImplementedError, match="`FSDPPrecision` is supported from PyTorch v1.12.0"): - FSDPPrecision(precision=16, device="cuda") + FSDPPrecision(precision="16-mixed", device="cuda") @RunIf(min_torch="1.12", min_cuda_gpus=1) -@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)]) +@pytest.mark.parametrize("precision, expected", [("16-mixed", torch.float16), ("bf16-mixed", torch.bfloat16)]) def test_fsdp_precision_config(precision, expected): plugin = FSDPPrecision(precision=precision, device="cuda") config = plugin.mixed_precision_config diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index a068d19cbc31d..6e4232459ac3b 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -151,7 +151,7 @@ def run(self): strategy=DeepSpeedStrategy(), accelerator="cuda", devices=1, - precision=16, + precision="16-mixed", ) fabric.run() @@ -175,7 +175,7 @@ def run(self): ) fabric = RunFabric( strategy=strategy, - precision=16, + precision="16-mixed", accelerator="cuda", 
devices=1, ) @@ -212,7 +212,7 @@ def run(self): ) fabric = RunFabric( strategy=strategy, - precision=16, + precision="16-mixed", accelerator="cuda", devices=1, ) @@ -247,7 +247,7 @@ def test_deepspeed_multigpu_stage_3(): strategy=DeepSpeedStrategy(stage=3), accelerator="cuda", devices=2, - precision=16, + precision="16-mixed", ) fabric.run() @@ -318,9 +318,9 @@ def step(self, model, batch): assert model.layer.weight.dtype == torch.bfloat16 return super().step(model, batch) - fabric = RunFabric(accelerator="cuda", devices=2, strategy="deepspeed_stage_3", precision="bf16") + fabric = RunFabric(accelerator="cuda", devices=2, strategy="deepspeed_stage_3", precision="bf16-mixed") assert isinstance(fabric._strategy.precision, DeepSpeedPrecision) - assert fabric._strategy.precision.precision == "bf16" + assert fabric._strategy.precision.precision == "bf16-mixed" assert fabric._strategy.config["zero_optimization"]["stage"] == 3 fabric.run() @@ -361,7 +361,7 @@ def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path): """Test that DeepSpeed stage 1, 2, and 3 model checkpoints can be saved and loaded successfully.""" from deepspeed import DeepSpeedEngine - fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16") + fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16-mixed") fabric.launch() checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint") diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index c80daee9723cc..c2040a5c4eab2 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -86,7 +86,7 @@ def _custom_auto_wrap_policy(module, recurse, unwrapped_params: int, min_num_par @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.13") -@pytest.mark.parametrize("precision", (16, pytest.param("bf16", marks=RunIf(bf16_cuda=True)))) +@pytest.mark.parametrize("precision", ("16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)))) @pytest.mark.parametrize("manual_wrapping", [True, False]) def test_fsdp_train_save_load(manual_wrapping, precision): """Test FSDP training, saving and loading with different wrapping and precision settings.""" diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index a10fd11c41eeb..051df16528540 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -43,7 +43,7 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script): assert "LT_STRATEGY" not in os.environ assert os.environ["LT_DEVICES"] == "1" assert os.environ["LT_NUM_NODES"] == "1" - assert os.environ["LT_PRECISION"] == "32" + assert os.environ["LT_PRECISION"] == "32-true" @pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))]) @@ -120,7 +120,7 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script): assert os.environ["LT_NUM_NODES"] == num_nodes -@pytest.mark.parametrize("precision", ["64", "32", "16", "bf16"]) +@pytest.mark.parametrize("precision", ["64-true", "64", "32-true", "32", "16-mixed", "bf16-mixed"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_precision(precision, monkeypatch, fake_script): monkeypatch.setattr(torch.distributed, "run", Mock()) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 
3d9f56b5076b7..8296e2426f4db 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -19,6 +19,7 @@ import pytest import torch import torch.distributed +from lightning_utilities.test.warning import no_warning_call from tests_fabric.helpers.runif import RunIf import lightning.fabric @@ -438,6 +439,38 @@ def test_validate_precision_type(precision): _Connector(precision=precision) +@pytest.mark.parametrize( + "precision,expected_precision,should_warn", + [ + (16, "16-mixed", True), + ("16", "16-mixed", True), + ("16-mixed", "16-mixed", False), + ("bf16", "bf16-mixed", True), + ("bf16-mixed", "bf16-mixed", False), + (32, "32-true", False), + ("32", "32-true", False), + ("32-true", "32-true", False), + (64, "64-true", False), + ("64", "64-true", False), + ("64-true", "64-true", False), + ], +) +# mock cuda as available to not be limited by dtype and accelerator compatibility - this is tested elsewhere +@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1) +@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False) +def test_precision_conversion(patch1, patch2, precision, expected_precision, should_warn): + warn_context = pytest.warns if should_warn else no_warning_call + with warn_context( + UserWarning, + match=( + f"{precision} is supported for historical reasons but its usage is discouraged. " + f"Please set your precision to {expected_precision} instead!" + ), + ): + connector = _Connector(precision=precision, accelerator="cuda") + assert connector._precision_input == expected_precision + + def test_multi_device_default_strategy(): """The default strategy when multiple devices are selected is "ddp" with the subprocess launcher.""" connector = _Connector(strategy=None, accelerator="cpu", devices=2) @@ -632,14 +665,14 @@ def test_strategy_choice_ddp_cpu_slurm(strategy): @mock.patch.dict(os.environ, {}, clear=True) @mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False) def test_unsupported_tpu_choice(_, tpu_available): - with pytest.raises(NotImplementedError, match=r"accelerator='tpu', precision=64\)` is not implemented"): - _Connector(accelerator="tpu", precision=64) + with pytest.raises(NotImplementedError, match=r"accelerator='tpu', precision='64-true'\)` is not implemented"): + _Connector(accelerator="tpu", precision="64-true") # if user didn't set strategy, _Connector will choose the TPUSingleStrategy or XLAStrategy with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"), pytest.warns( - UserWarning, match=r"accelerator='tpu', precision=16\)` but AMP is not supported" + UserWarning, match=r"accelerator='tpu', precision='16-mixed'\)` but AMP with fp16 is not supported" ): - _Connector(accelerator="tpu", precision=16, strategy="ddp") + _Connector(accelerator="tpu", precision="16-mixed", strategy="ddp") # wrong precision plugin type strategy = XLAStrategy(accelerator=TPUAccelerator(), precision=Precision()) @@ -760,8 +793,11 @@ def test_ddp_fork_on_unsupported_platform(_, __, strategy): def test_precision_selection_16_on_cpu_warns(): - with pytest.warns(UserWarning, match=r"precision=16\)` but AMP is not supported on CPU. Using `precision='bf16"): - _Connector(precision=16) + with pytest.warns( + UserWarning, + match=r"precision='16-mixed'\)` but AMP with fp16 is not supported on CPU. 
Using `precision='bf16-mixed'", + ): + _Connector(precision="16-mixed") class MyAMP(MixedPrecision): @@ -777,9 +813,9 @@ class MyAMP(MixedPrecision): def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin_cls): plugin = None if is_custom_plugin: - plugin = plugin_cls(16, "cpu") + plugin = plugin_cls("16-mixed", "cpu") connector = _Connector( - precision=16, + precision="16-mixed", devices=devices, strategy=strategy, plugins=plugin, @@ -794,7 +830,7 @@ def test_strategy_str_passed_being_case_insensitive(_, strategy, strategy_cls): assert isinstance(connector.strategy, strategy_cls) -@pytest.mark.parametrize("precision", ["64", "32", "16", "bf16"]) +@pytest.mark.parametrize("precision", ["64-true", "32-true", "16-mixed", "bf16-mixed"]) @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1) def test_precision_from_environment(_, precision): """Test that the precision input can be set through the environment variable.""" @@ -856,9 +892,9 @@ def test_arguments_from_environment_collision(): with pytest.raises(ValueError, match="`Fabric\\(num_nodes=2, ...\\)` but .* `--num_nodes=3`"): _Connector(num_nodes=2) - with mock.patch.dict(os.environ, {"LT_PRECISION": "16"}): - with pytest.raises(ValueError, match="`Fabric\\(precision=64, ...\\)` but .* `--precision=16`"): - _Connector(precision=64) + with mock.patch.dict(os.environ, {"LT_PRECISION": "16-mixed"}): + with pytest.raises(ValueError, match="`Fabric\\(precision='64-true', ...\\)` but .* `--precision=16-mixed`"): + _Connector(precision="64-true") @RunIf(min_torch="1.12")
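
The hunks above all implement one idea: every precision flag becomes an explicit string of the form dtype-mode ("16-mixed", "bf16-mixed", "32-true", "64-true"), legacy int/short-string spellings are funneled through a single alias table that warns before converting, and only "16-mixed" gets a GradScaler. The sketch below is a minimal, self-contained illustration of that scheme, not the Lightning implementation: the helper names (_LEGACY_ALIASES, convert_precision, autocast_dtype, needs_grad_scaler) are hypothetical stand-ins; only the flag values, the alias-to-unified mapping, and the wording of the discouragement warning are taken from the diff itself.

    # Minimal sketch of the unified precision-flag scheme, under the assumptions stated above.
    import warnings
    from typing import Union

    import torch

    # Legacy spellings (ints or short strings) mapped to the unified flags.
    _LEGACY_ALIASES = {
        64: "64-true", "64": "64-true",
        32: "32-true", "32": "32-true",
        16: "16-mixed", "16": "16-mixed",
        "bf16": "bf16-mixed",
    }

    _UNIFIED = {"64-true", "32-true", "16-mixed", "bf16-mixed"}


    def convert_precision(precision: Union[int, str]) -> str:
        """Normalize a user-supplied precision flag to its unified string form."""
        if precision in _UNIFIED:
            return str(precision)  # already unified, nothing to do
        if precision in _LEGACY_ALIASES:
            unified = _LEGACY_ALIASES[precision]
            # The 16/bf16 aliases are discouraged; the 32/64 aliases convert silently.
            if str(precision) not in ("32", "64"):
                warnings.warn(
                    f"{precision} is supported for historical reasons but its usage is"
                    f" discouraged. Please set your precision to {unified} instead!"
                )
            return unified
        raise ValueError(f"Precision {precision!r} is invalid.")


    def autocast_dtype(precision: str) -> torch.dtype:
        """Pick the torch dtype an AMP-style plugin would hand to torch.autocast."""
        return {"16-mixed": torch.float16, "bf16-mixed": torch.bfloat16}[precision]


    def needs_grad_scaler(precision: str) -> bool:
        """Only fp16 mixed precision uses a torch.cuda.amp.GradScaler; bf16 must not."""
        return precision == "16-mixed"


    if __name__ == "__main__":
        assert convert_precision(16) == "16-mixed"        # warns, then converts
        assert convert_precision("bf16") == "bf16-mixed"  # warns, then converts
        assert convert_precision("32-true") == "32-true"  # passes through silently
        assert autocast_dtype("bf16-mixed") is torch.bfloat16
        assert not needs_grad_scaler("bf16-mixed")

Read against the tests in the diff, user code that previously passed Fabric(precision=16) would now pass Fabric(precision="16-mixed"); the old spelling keeps working through the alias table but emits the warning modeled above.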