diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index d2bd534e33776..424b82ce532ce 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -98,18 +98,9 @@ RUN \ pip install -r requirements/pytorch/base.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \ rm assistant.py -RUN \ - # install ColossalAI - # TODO: 1.13 wheels are not released, remove skip once they are - if [[ $PYTORCH_VERSION != "1.13" ]]; then \ - pip install "colossalai==0.2.4"; \ - python -c "import colossalai; print(colossalai.__version__)" ; \ - fi RUN \ # install rest of strategies - # remove colossalai from requirements since they are installed separately - python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \ cat requirements/pytorch/strategies.txt && \ pip install -r requirements/pytorch/devel.txt -r requirements/pytorch/strategies.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html diff --git a/dockers/nvidia/Dockerfile b/dockers/nvidia/Dockerfile index 9bb97e92af04e..cb76595f3eac7 100644 --- a/dockers/nvidia/Dockerfile +++ b/dockers/nvidia/Dockerfile @@ -43,8 +43,6 @@ RUN \ # Installations \ pip install "Pillow>=8.2, !=8.3.0" "cryptography>=3.4" "py>=1.10" --no-cache-dir && \ - # remove colossalai from requirements since they are installed separately - python -c "fname = 'lightning/requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \ PACKAGE_NAME=pytorch pip install './lightning[extra,loggers,strategies]' --no-cache-dir && \ rm -rf lightning && \ pip list diff --git a/docs/source-pytorch/accelerators/tpu_faq.rst b/docs/source-pytorch/accelerators/tpu_faq.rst index f38f0a865b4cd..de4cd315e4cdb 100644 --- a/docs/source-pytorch/accelerators/tpu_faq.rst +++ b/docs/source-pytorch/accelerators/tpu_faq.rst @@ -61,7 +61,7 @@ How to resolve the replication issue? .format(len(local_devices), len(kind_devices))) RuntimeError: Cannot replicate if number of devices (1) is different from 8 -This error is raised when the XLA device is called outside the spawn process. Internally in `TPUSpawn` Strategy for training on multiple tpu cores, we use XLA's `xmp.spawn`. +This error is raised when the XLA device is called outside the spawn process. Internally in the XLA-Strategy for training on multiple tpu cores, we use XLA's `xmp.spawn`. Don't use ``xm.xla_device()`` while working on Lightning + TPUs! ---- @@ -91,7 +91,7 @@ How to setup the debug mode for Training on TPUs? import pytorch_lightning as pl my_model = MyLightningModule() - trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="tpu_spawn_debug") + trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="xla_debug") trainer.fit(my_model) Example Metrics report: @@ -108,7 +108,7 @@ Example Metrics report: A lot of PyTorch operations aren't lowered to XLA, which could lead to significant slowdown of the training process. These operations are moved to the CPU memory and evaluated, and then the results are transferred back to the XLA device(s). -By using the `tpu_spawn_debug` Strategy, users could create a metrics report to diagnose issues. +By using the `xla_debug` Strategy, users could create a metrics report to diagnose issues. 
The report includes things like (`XLA Reference `_): diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst index 9b3030f02ec8c..6603eae0da6c9 100644 --- a/docs/source-pytorch/advanced/model_parallel.rst +++ b/docs/source-pytorch/advanced/model_parallel.rst @@ -37,7 +37,7 @@ This means we cannot sacrifice throughput as much as if we were fine-tuning, bec Overall: * When **fine-tuning** a model, use advanced memory efficient strategies such as :ref:`fully-sharded-training`, :ref:`deepspeed-zero-stage-3` or :ref:`deepspeed-zero-stage-3-offload`, allowing you to fine-tune larger models if you are limited on compute -* When **pre-training** a model, use simpler optimizations such :ref:`sharded-training` or :ref:`deepspeed-zero-stage-2`, scaling the number of GPUs to reach larger parameter sizes +* When **pre-training** a model, use simpler optimizations such as :ref:`deepspeed-zero-stage-2`, scaling the number of GPUs to reach larger parameter sizes * For both fine-tuning and pre-training, use :ref:`deepspeed-activation-checkpointing` as the throughput degradation is not significant For example when using 128 GPUs, you can **pre-train** large 10 to 20 Billion parameter models using :ref:`deepspeed-zero-stage-2` without having to take a performance hit with more advanced optimized multi-gpu strategy. @@ -52,133 +52,17 @@ Sharding techniques help when model sizes are fairly large; roughly 500M+ parame * When your model is small (ResNet50 of around 80M Parameters), unless you are using unusually large batch sizes or inputs. * Due to high distributed communication between devices, if running on a slow network/interconnect, the training might be much slower than expected and then it's up to you to determince the tradeoff here. ----------- - -.. _colossalai: - -*********** -Colossal-AI -*********** - -:class:`~pytorch_lightning.strategies.colossalai.ColossalAIStrategy` implements ZeRO-DP with chunk-based memory management. -With this chunk mechanism, really large models can be trained with a small number of GPUs. -It supports larger trainable model size and batch size than usual heterogeneous training by reducing CUDA memory fragments and CPU memory consumption. -Also, it speeds up this kind of heterogeneous training by fully utilizing all kinds of resources. - -When enabling chunk mechanism, a set of consecutive parameters are stored in a chunk, and then the chunk is sharded across different processes. -This can reduce communication and data transmission frequency and fully utilize communication and PCI-E bandwidth, which makes training faster. - -Unlike traditional implementations, which adopt static memory partition, we implemented a dynamic heterogeneous memory management system named Gemini. -During the first training step, the warmup phase will sample the maximum non-model data memory (memory usage expect parameters, gradients, and optimizer states). -In later training, it will use the collected memory usage information to evict chunks dynamically. -Gemini allows you to fit much larger models with limited GPU memory. - -According to our benchmark results, we can train models with up to 24 billion parameters in 1 GPU. -You can install colossalai by consulting `how to download colossalai `_. -Then, run this benchmark in `Colossalai-PL/gpt `_. - -Here is an example showing how to use ColossalAI: - -.. code-block:: python - - from colossalai.nn.optimizer import HybridAdam - - - class MyBert(LightningModule): - ... 
- - def configure_sharded_model(self) -> None: - # create your model here - self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased") - - def configure_optimizers(self): - # use the specified optimizer - optimizer = HybridAdam(self.model.parameters(), self.lr) - - ... - - - model = MyBert() - trainer = Trainer(accelerator="gpu", devices=1, precision=16, strategy="colossalai") - trainer.fit(model) - -You can find more examples in the `Colossalai-PL `_ repository. - -.. note:: - - * The only accelerator which ColossalAI supports is ``"gpu"``. But CPU resources will be used when the placement policy is set to "auto" or "cpu". - * The only precision which ColossalAI allows is 16 (FP16). +Cutting-edge and Experimental Strategies +======================================== - * It only supports a single optimizer, which must be ``colossalai.nn.optimizer.CPUAdam`` or ``colossalai.nn.optimizer. - HybridAdam`` now. You can set ``adamw_mode`` to False to use normal Adam. Noticing that ``HybridAdam`` is highly optimized, it uses fused CUDA kernel and parallel CPU kernel. - It is recomended to use ``HybridAdam``, since it updates parameters in GPU and CPU both. +Cutting-edge Lightning strategies are being developed by third-parties outside of Lightning. +If you want to be the first to try the latest and greatest experimental features for model-parallel training, check out the :doc:`Colossal-AI Strategy <./third_party/colossalai>` integration. - * Your model must be created using the :meth:`~pytorch_lightning.core.module.LightningModule.configure_sharded_model` method. - - * ``ColossalaiStrategy`` doesn't support gradient accumulation as of now. - -.. _colossal_placement_policy: - -Placement Policy -================ - -Placement policies can help users fully exploit their GPU-CPU heterogeneous memory space for better training efficiency. -There are three options for the placement policy. -They are "cpu", "cuda" and "auto" respectively. - -When the placement policy is set to "cpu", all participated parameters will be offloaded into CPU memory immediately at the end of every auto-grad operation. -In this way, "cpu" placement policy uses the least CUDA memory. -It is the best choice for users who want to exceptionally enlarge their model size or training batch size. - -When using "cuda" option, all parameters are placed in the CUDA memory, no CPU resources will be used during the training. -It is for users who get plenty of CUDA memory. - -The third option, "auto", enables Gemini. -It monitors the consumption of CUDA memory during the warmup phase and collects CUDA memory usage of all auto-grad operations. -In later training steps, Gemini automatically manages the data transmission between GPU and CPU according to collected CUDA memory usage information. -It is the fastest option when CUDA memory is enough. - -Here's an example of changing the placement policy to "cpu". - -.. code-block:: python - - from pytorch_lightning.strategies import ColossalAIStrategy - - model = MyModel() - my_strategy = ColossalAIStrategy(placement_policy="cpu") - trainer = Trainer(accelerator="gpu", devices=4, precision=16, strategy=my_strategy) - trainer.fit(model) - -.. _sharded-training: - -**************** -Sharded Training -**************** - -The technique can be found within `DeepSpeed ZeRO `_ and -`ZeRO-2 `_, -however the implementation is built from the ground up to be PyTorch compatible and standalone. 
-Sharded Training allows you to maintain GPU scaling efficiency, whilst reducing memory overhead drastically. In short, expect near-normal linear scaling (if your network allows), and significantly reduced memory usage when training large models. - -Sharded Training still utilizes Data Parallel Training under the hood, except optimizer states and gradients are sharded across GPUs. -This means the memory overhead per GPU is lower, as each GPU only has to maintain a partition of your optimizer state and gradients. - -The benefits vary by model and parameter sizes, but we've recorded up to a 63% memory reduction per GPU allowing us to double our model sizes. Because of efficient communication, -these benefits in multi-GPU setups are almost free and throughput scales well with multi-node setups. - -It is highly recommended to use Sharded Training in multi-GPU environments where memory is limited, or where training larger models are beneficial (500M+ parameter models). -A technical note: as batch size scales, storing activations for the backwards pass becomes the bottleneck in training. As a result, sharding optimizer state and gradients becomes less impactful. - -.. code-block:: python - - # train using Sharded DDP - trainer = Trainer(strategy="ddp_sharded") - -Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required. ---- + .. _fully-sharded-training: ********************** diff --git a/docs/source-pytorch/advanced/strategy_registry.rst b/docs/source-pytorch/advanced/strategy_registry.rst index 27bab6ea49df4..914db517eb121 100644 --- a/docs/source-pytorch/advanced/strategy_registry.rst +++ b/docs/source-pytorch/advanced/strategy_registry.rst @@ -18,7 +18,7 @@ It also returns the optional description and parameters for initialising the Str trainer = Trainer(strategy="deepspeed_stage_3_offload", accelerator="gpu", devices=3) # Training with the TPU Spawn Strategy with `debug` as True - trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8) + trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8) Additionally, you can pass your custom registered training strategies to the ``strategy`` argument. diff --git a/docs/source-pytorch/advanced/third_party/colossalai.rst b/docs/source-pytorch/advanced/third_party/colossalai.rst new file mode 100644 index 0000000000000..5223bdc0ad60d --- /dev/null +++ b/docs/source-pytorch/advanced/third_party/colossalai.rst @@ -0,0 +1,92 @@ +:orphan: + +########### +Colossal-AI +########### + + +The Colossal-AI strategy implements ZeRO-DP with chunk-based memory management. +With this chunk mechanism, really large models can be trained with a small number of GPUs. +It supports larger trainable model size and batch size than usual heterogeneous training by reducing CUDA memory fragments and CPU memory consumption. +Also, it speeds up this kind of heterogeneous training by fully utilizing all kinds of resources. + +When enabling chunk mechanism, a set of consecutive parameters are stored in a chunk, and then the chunk is sharded across different processes. +This can reduce communication and data transmission frequency and fully utilize communication and PCI-E bandwidth, which makes training faster. + +Unlike traditional implementations, which adopt static memory partition, we implemented a dynamic heterogeneous memory management system named Gemini. 
+During the first training step, the warmup phase will sample the maximum non-model data memory (memory usage except parameters, gradients, and optimizer states). +In later training, it will use the collected memory usage information to evict chunks dynamically. +Gemini allows you to fit much larger models with limited GPU memory. + +According to our benchmark results, we can train models with up to 24 billion parameters on a single GPU. + +You can install the Colossal-AI integration by running + +.. code-block:: bash + + pip install lightning-colossalai + +This will install both the `colossalai `_ package and the ``ColossalAIStrategy`` for the Lightning Trainer: + +.. code-block:: python + + trainer = Trainer(strategy="colossalai", precision=16, devices=...) + + +You can tune several settings by instantiating the strategy object and passing options in: + +.. code-block:: python + + from lightning_colossalai import ColossalAIStrategy + + strategy = ColossalAIStrategy(...) + trainer = Trainer(strategy=strategy, precision=16, devices=...) + + +See a full benchmark example with a `GPT-2 model `_ of up to 24 billion parameters. + +.. note:: + + * The only accelerator which ColossalAI supports is ``"gpu"``, but CPU resources will be used when the placement policy is set to "auto" or "cpu". + + * The only precision which ColossalAI allows is 16-bit mixed precision (FP16). + + * It only supports a single optimizer, which must be ``colossalai.nn.optimizer.CPUAdam`` or + ``colossalai.nn.optimizer.HybridAdam`` for now. You can set ``adamw_mode`` to False to use plain Adam. Note that ``HybridAdam`` is highly optimized: it uses a fused CUDA kernel and a parallel CPU kernel. + It is recommended to use ``HybridAdam``, since it updates parameters on both GPU and CPU. + + * Your model must be created using the :meth:`~pytorch_lightning.core.module.LightningModule.configure_sharded_model` method. + + * ``ColossalAIStrategy`` doesn't support gradient accumulation as of now. + +.. _colossal_placement_policy: + +Placement Policy +================ + +Placement policies can help users fully exploit their GPU-CPU heterogeneous memory space for better training efficiency. +There are three options for the placement policy. +They are "cpu", "cuda" and "auto" respectively. + +When the placement policy is set to "cpu", all participating parameters will be offloaded into CPU memory immediately at the end of every auto-grad operation. +In this way, the "cpu" placement policy uses the least CUDA memory. +It is the best choice for users who want to dramatically scale up their model size or training batch size. + +When using the "cuda" option, all parameters are placed in CUDA memory, and no CPU resources will be used during training. +It is for users who have plenty of CUDA memory. + +The third option, "auto", enables Gemini. +It monitors the consumption of CUDA memory during the warmup phase and collects CUDA memory usage of all auto-grad operations. +In later training steps, Gemini automatically manages the data transmission between GPU and CPU according to collected CUDA memory usage information. +It is the fastest option when CUDA memory is sufficient. + +Here's an example of changing the placement policy to "cpu". + +.. 
code-block:: python + + from lightning_colossalai import ColossalAIStrategy + + model = MyModel() + my_strategy = ColossalAIStrategy(placement_policy="cpu") + trainer = Trainer(accelerator="gpu", devices=4, precision=16, strategy=my_strategy) + trainer.fit(model) diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 7e4eb3ca8863f..8187a74ff49fd 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -222,7 +222,7 @@ strategies SingleHPUStrategy SingleTPUStrategy Strategy - TPUSpawnStrategy + XLAStrategy tuner ----- diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst index 429131ef03944..6b7474204e6bd 100644 --- a/docs/source-pytorch/extensions/strategy.rst +++ b/docs/source-pytorch/extensions/strategy.rst @@ -23,7 +23,7 @@ plugin and other optional plugins such as the :ref:`ClusterEnvironment `_ itself). ----------- +---- ***************************** Selecting a Built-in Strategy @@ -69,9 +69,6 @@ The below table lists all relevant strategies available in Lightning with their * - Name - Class - Description - * - colossalai - - :class:`~pytorch_lightning.strategies.ColossalAIStrategy` - - Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. `__ * - fsdp - :class:`~pytorch_lightning.strategies.FSDPStrategy` - Strategy for Fully Sharded Data Parallel training. :ref:`Learn more. ` @@ -93,8 +90,8 @@ The below table lists all relevant strategies available in Lightning with their * - ipu_strategy - :class:`~pytorch_lightning.strategies.IPUStrategy` - Plugin for training on IPU devices. :doc:`Learn more. <../accelerators/ipu>` - * - tpu_spawn - - :class:`~pytorch_lightning.strategies.TPUSpawnStrategy` + * - xla + - :class:`~pytorch_lightning.strategies.XLAStrategy` - Strategy for training on multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` method. :doc:`Learn more. <../accelerators/tpu>` * - single_tpu - :class:`~pytorch_lightning.strategies.SingleTPUStrategy` @@ -102,6 +99,28 @@ The below table lists all relevant strategies available in Lightning with their ---- + +********************** +Third-party Strategies +********************** + +There are powerful third-party strategies that integrate well with Lightning but aren't maintained as part of the ``lightning`` package. + +.. list-table:: List of third-party strategy implementations + :widths: 20 20 20 + :header-rows: 1 + + * - Name + - Package + - Description + * - colossalai + - `Lightning-AI/lightning-colossalai `_ + - Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. 
`__ + + +---- + + ************************ Create a Custom Strategy ************************ diff --git a/docs/source-pytorch/fabric/api/fabric_args.rst b/docs/source-pytorch/fabric/api/fabric_args.rst index 2006a3d1670d9..3ee1fe9e9529f 100644 --- a/docs/source-pytorch/fabric/api/fabric_args.rst +++ b/docs/source-pytorch/fabric/api/fabric_args.rst @@ -122,6 +122,9 @@ This can result in improved performance, achieving significant speedups on moder # Default used by the Fabric fabric = Fabric(precision="32-true", devices=1) + # the same as: + fabric = Fabric(precision="32", devices=1) + # 16-bit (mixed) precision fabric = Fabric(precision="16-mixed", devices=1) diff --git a/docs/source-pytorch/fabric/fundamentals/launch.rst b/docs/source-pytorch/fabric/fundamentals/launch.rst index 95959afd83084..af766c56e4a0c 100644 --- a/docs/source-pytorch/fabric/fundamentals/launch.rst +++ b/docs/source-pytorch/fabric/fundamentals/launch.rst @@ -68,11 +68,12 @@ This is essentially the same as running ``python path/to/your/script.py``, but i --main-port, --main_port INTEGER The main port to connect to the main machine. - --precision [16-mixed|bf16-mixed|32-true|64-true] - Double precision (``64-true``), full - precision (``32-true``), half precision - (``16-mixed``) or bfloat16 precision - (``'bf16-mixed'``) + --precision [16-mixed|bf16-mixed|32-true|64-true|64|32|16|bf16] + Double precision (``64-true`` or ``64``), + full precision (``32-true`` or ``32``), half + precision (``16-mixed`` or ``16``) or + bfloat16 precision (``bf16-mixed`` or + ``bf16``) --help Show this message and exit. diff --git a/docs/source-pytorch/fabric/fundamentals/precision.rst b/docs/source-pytorch/fabric/fundamentals/precision.rst index 7a9a0c4692881..5d24b41ba4e54 100644 --- a/docs/source-pytorch/fabric/fundamentals/precision.rst +++ b/docs/source-pytorch/fabric/fundamentals/precision.rst @@ -26,6 +26,12 @@ This is how you select the precision in Fabric: # This is the default fabric = Fabric(precision="32-true") + # Also FP32 + fabric = Fabric(precision=32) + + # FP32 as well + fabric = Fabric(precision="32") + # FP16 mixed precision fabric = Fabric(precision="16-mixed) @@ -35,6 +41,12 @@ This is how you select the precision in Fabric: # Double precision fabric = Fabric(precision="64-true") + # Or + fabric = Fabric(precision="64") + + # Or + fabric = Fabric(precision=64) + The same values can also be set through the :doc:`command line interface `: diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt index c8a5c9531fe3d..4db2eb301121b 100644 --- a/requirements/pytorch/strategies.txt +++ b/requirements/pytorch/strategies.txt @@ -2,3 +2,4 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment deepspeed>=0.6.0, <0.8.0 # TODO: Include 0.8.x after https://github.com/microsoft/DeepSpeed/commit/b587c7e85470329ac25df7c7c2521ff9b2833db7 gets released +lightning-colossalai==0.1.0dev diff --git a/src/lightning/app/components/serve/auto_scaler.py b/src/lightning/app/components/serve/auto_scaler.py index ed6fdbf6a4a73..da540c37baa3c 100644 --- a/src/lightning/app/components/serve/auto_scaler.py +++ b/src/lightning/app/components/serve/auto_scaler.py @@ -453,7 +453,9 @@ def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F82 try: from lightning_api_access import APIAccessFrontend except ModuleNotFoundError: - logger.warn("APIAccessFrontend not found. 
Please install lightning-api-access to enable the UI") + logger.warn( + "Some dependencies to run the UI are missing. To resolve, run `pip install lightning-api-access`" + ) return if is_running_in_cloud(): diff --git a/src/lightning/app/components/serve/python_server.py b/src/lightning/app/components/serve/python_server.py index a914135e2cce3..e70335a723ddb 100644 --- a/src/lightning/app/components/serve/python_server.py +++ b/src/lightning/app/components/serve/python_server.py @@ -293,7 +293,9 @@ def configure_layout(self) -> Optional["Frontend"]: try: from lightning_api_access import APIAccessFrontend except ModuleNotFoundError: - logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI") + logger.warn( + "Some dependencies to run the UI are missing. To resolve, run `pip install lightning-api-access`" + ) return class_name = self.__class__.__name__ diff --git a/src/lightning/app/core/work.py b/src/lightning/app/core/work.py index abe57ab45f38c..1eb7cacbc1fa6 100644 --- a/src/lightning/app/core/work.py +++ b/src/lightning/app/core/work.py @@ -639,7 +639,10 @@ def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus: return WorkStatus(**status, count=len(timeout_statuses)) def on_exit(self): - """Override this hook to add your logic when the work is exiting.""" + """Override this hook to add your logic when the work is exiting. + + Note: This hook is not guaranteed to be called when running in the cloud. + """ pass def stop(self): diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py index fb3576b48d22d..db4e5d9f9afdf 100644 --- a/src/lightning/app/utilities/network.py +++ b/src/lightning/app/utilities/network.py @@ -66,7 +66,7 @@ def find_free_network_port() -> int: def _find_free_network_port_cloudspace(): """Finds a free port in the exposed range when running in a cloudspace.""" for port in range( - constants.APP_SERVER_PORT, + constants.APP_SERVER_PORT + 1, # constants.APP_SERVER_PORT is reserved for the app server constants.APP_SERVER_PORT + constants.LIGHTNING_CLOUDSPACE_EXPOSED_PORT_COUNT, ): if port in _reserved_ports: diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 3e826b52d57f4..1b9543d63e9cc 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -28,7 +28,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `DataParallelStrategy.get_module_state_dict()` and `DDPStrategy.get_module_state_dict()` now correctly extracts the state dict without keys prefixed with 'module' ([#16487](https://github.com/Lightning-AI/lightning/pull/16487)) - - "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490)) * `strategy="fsdp_full_shard_offload"` is now `strategy="fsdp_cpu_offload"` * `lightning.fabric.plugins.precision.native_amp` is now `lightning.fabric.plugins.precision.amp` @@ -36,6 +35,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Enabled all shorthand strategy names that can be supported in the CLI ([#16485](https://github.com/Lightning-AI/lightning/pull/16485)) +- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781)) + + +- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16767](https://github.com/Lightning-AI/lightning/pull/16767)) - Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16767](https://github.com/Lightning-AI/lightning/pull/16767)) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index f6b5bf7112d6e..4671a75da7577 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -21,7 +21,7 @@ from typing_extensions import get_args from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator -from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR +from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS from lightning.fabric.strategies import STRATEGY_REGISTRY from lightning.fabric.utilities.device_parser import _parse_gpu_ids @@ -107,11 +107,11 @@ def _get_supported_strategies() -> List[str]: ) @click.option( "--precision", - type=click.Choice(get_args(_PRECISION_INPUT_STR)), + type=click.Choice(get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_STR_ALIAS)), default="32-true", help=( - "Double precision (``64-true``), full precision (``32-true``), half precision (``16-mixed``) or " - "bfloat16 precision (``'bf16-mixed'``)" + "Double precision (``64-true`` or ``64``), full precision (``32-true`` or ``32``), " + "half precision (``16-mixed`` or ``16``) or bfloat16 precision (``bf16-mixed`` or ``bf16``)" ), ) @click.argument("script_args", nargs=-1, type=click.UNPROCESSED) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 05e76f926d097..c8276f6ef8459 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -46,8 +46,8 @@ _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR, - _PRECISION_INPUT_STR_LEGACY, - _PRECISION_INPUT_STR_LEGACY_CONVERSION, + _PRECISION_INPUT_STR_ALIAS, + _PRECISION_INPUT_STR_ALIAS_CONVERSION, ) from lightning.fabric.strategies import ( DeepSpeedStrategy, @@ -226,22 +226,7 @@ def _check_config_and_set_final_flags( self._accelerator_flag = accelerator - supported_precision = ( - get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_LEGACY) - ) - if precision not in supported_precision: - raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}") - - precision = str(precision) # convert int flags to str here to enable the legacy-conversion below - - if precision in get_args(_PRECISION_INPUT_STR_LEGACY): - rank_zero_warn( - f"{precision} is supported for historical reasons but its usage is discouraged. " - f"Please set your precision to {_PRECISION_INPUT_STR_LEGACY_CONVERSION[precision]} instead!"
- ) - precision = _PRECISION_INPUT_STR_LEGACY_CONVERSION[precision] - - self._precision_input = cast(_PRECISION_INPUT_STR, precision) + self._precision_input = _convert_precision_to_unified_args(precision) if plugins: plugins_flags_types: Dict[str, int] = Counter() @@ -403,7 +388,7 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment: def _choose_strategy(self) -> Union[Strategy, str]: if self._accelerator_flag == "tpu": if self._parallel_devices and len(self._parallel_devices) > 1: - return "tpu_spawn" + return "xla" else: # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device" return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore @@ -579,3 +564,22 @@ def _argument_from_env(name: str, current: Any, default: Any) -> Any: if env_value is None: return current return env_value + + +def _convert_precision_to_unified_args(precision: _PRECISION_INPUT) -> _PRECISION_INPUT_STR: + supported_precision = ( + get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_ALIAS) + ) + if precision not in supported_precision: + raise ValueError(f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}") + + precision = str(precision) # convert int flags to str here to enable the legacy-conversion below + + if precision in get_args(_PRECISION_INPUT_STR_ALIAS): + if str(precision)[:2] not in ("32", "64"): + rank_zero_warn( + f"{precision} is supported for historical reasons but its usage is discouraged. " + f"Please set your precision to {_PRECISION_INPUT_STR_ALIAS_CONVERSION[precision]} instead!" + ) + precision = _PRECISION_INPUT_STR_ALIAS_CONVERSION[precision] + return cast(_PRECISION_INPUT_STR, precision) diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index 2d813da1ec657..44195f823e04f 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, cast, Literal, TYPE_CHECKING +from typing import Any, Literal, TYPE_CHECKING import torch from torch import Tensor @@ -48,7 +48,7 @@ def __init__(self, precision: _PRECISION_INPUT) -> None: f"`precision={precision!r})` is not supported in DeepSpeed." f" `precision` must be one of: {supported_precision}." 
) - self.precision = cast(_PRECISION_INPUT, precision) + self.precision = precision def convert_input(self, data: Tensor) -> Tensor: precision_to_type = {"bf16-mixed": torch.bfloat16, "16-mixed": torch.float16, "32-true": torch.float32} diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index 2c8e3625cdc86..e1add043662fe 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -23,10 +23,10 @@ from lightning.fabric.utilities.types import _PARAMETERS, Optimizable _PRECISION_INPUT_INT = Literal[64, 32, 16] -_PRECISION_INPUT_STR_LEGACY_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"} -_PRECISION_INPUT_STR_LEGACY = Literal["64", "32", "16", "bf16"] +_PRECISION_INPUT_STR_ALIAS_CONVERSION = {"64": "64-true", "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed"} +_PRECISION_INPUT_STR_ALIAS = Literal["64", "32", "16", "bf16"] _PRECISION_INPUT_STR = Literal["16-mixed", "bf16-mixed", "32-true", "64-true"] -_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_LEGACY] +_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS] class Precision: diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index ffaec51891412..66624239a714c 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -13,7 +13,7 @@ # limitations under the License. import io import os -from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union import torch from torch import Tensor @@ -30,7 +30,6 @@ from lightning.fabric.strategies import ParallelStrategy from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.strategies.strategy import TBroadcast -from lightning.fabric.utilities.apply_func import apply_to_collection from lightning.fabric.utilities.data import has_len from lightning.fabric.utilities.rank_zero import rank_zero_only from lightning.fabric.utilities.types import _PATH, ReduceOp @@ -210,8 +209,6 @@ def remove_checkpoint(self, filepath: _PATH) -> None: @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: - # TODO(fabric): Deprecate the name "tpu_spawn" through the connector - strategy_registry.register("tpu_spawn", cls, description=cls.__class__.__name__) strategy_registry.register("xla", cls, description=cls.__class__.__name__) def _set_world_ranks(self) -> None: @@ -222,12 +219,9 @@ def _set_world_ranks(self) -> None: rank_zero_only.rank = self.cluster_environment.global_rank() @staticmethod - def _validate_dataloader(dataloaders: DataLoader) -> None: - def check_has_len(dataloader: DataLoader) -> None: - if not has_len(dataloader): - raise TypeError( - "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." - " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." - ) - - apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len) + def _validate_dataloader(dataloader: object) -> None: + if not has_len(dataloader): + raise TypeError( + "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." + " HINT: You can mock the length on your dataset to bypass this error." 
+ ) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index a85d8d73c816d..f39ecdb2bdc25 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646)) -- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743)) +- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743), [#16784](https://github.com/Lightning-AI/lightning/pull/16784)) ### Changed @@ -89,6 +89,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `dataloader_idx` argument is now optional for the `on_{validation,test,predict}_batch_{start,end}` hooks. Remove it or default it to 0 if you don't use multiple dataloaders ([#16753](https://github.com/Lightning-AI/lightning/pull/16753)) + +- Renamed `TPUSpawnStrategy` to `XLAStrategy` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781)) + +- Renamed `strategy='tpu_spawn'` to `strategy='xla'` and `strategy='tpu_spawn_debug'` to `strategy='xla_debug'` ([#16781](https://github.com/Lightning-AI/lightning/pull/16781)) + + ### Deprecated - @@ -297,6 +303,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the unused `lightning.pytorch.utilities.supporters.{SharedCycleIteratorState,CombinedLoaderIterator}` classes ([#16714](https://github.com/Lightning-AI/lightning/pull/16714)) +- Removed `ProgressBarBase.{train_batch_idx,val_batch_idx,test_batch_idx,predict_batch_idx}` properties ([#16760](https://github.com/Lightning-AI/lightning/pull/16760)) + + + - Removed the `Trainer(track_grad_norm=...)` argument ([#16745](https://github.com/Lightning-AI/lightning/pull/16745)) @@ -306,6 +316,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the `QuantizationAwareTraining` callback ([#16750](https://github.com/Lightning-AI/lightning/pull/16750)) +- Removed the `ColossalAIStrategy` and `ColossalAIPrecisionPlugin` in favor of the new [lightning-colossalai](https://github.com/Lightning-AI/lightning-colossalai) package ([#16757](https://github.com/Lightning-AI/lightning/pull/16757), [#16778](https://github.com/Lightning-AI/lightning/pull/16778)) + + ### Fixed - Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693)) @@ -314,6 +327,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719)) +- Fixed bug where `set_epoch` was not called for prediction dataloaders ([#16785](https://github.com/Lightning-AI/lightning/pull/16785)) + ## [1.9.1] - 2023-02-10 ### Fixed diff --git a/src/lightning/pytorch/callbacks/progress/base.py b/src/lightning/pytorch/callbacks/progress/base.py index 041783ca68b1b..c0492d5bf314e 100644 --- a/src/lightning/pytorch/callbacks/progress/base.py +++ b/src/lightning/pytorch/callbacks/progress/base.py @@ -76,49 +76,6 @@ def test_description(self) -> str: def predict_description(self) -> str: return "Predicting" - @property - def _val_processed(self) -> int: - # use total in case validation runs more than once per training epoch - return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed - - @property - def train_batch_idx(self) -> int: - """The number of batches processed during training. - - Use this to update your progress bar. - """ - return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed - - @property - def val_batch_idx(self) -> int: - """The number of batches processed during validation. - - Use this to update your progress bar. - """ - if self.trainer.state.fn == "fit": - loop = self.trainer.fit_loop.epoch_loop.val_loop - else: - loop = self.trainer.validate_loop - - current_batch_idx = loop.epoch_loop.batch_progress.current.processed - return current_batch_idx - - @property - def test_batch_idx(self) -> int: - """The number of batches processed during testing. - - Use this to update your progress bar. - """ - return self.trainer.test_loop.epoch_loop.batch_progress.current.processed - - @property - def predict_batch_idx(self) -> int: - """The number of batches processed during prediction. - - Use this to update your progress bar. - """ - return self.trainer.predict_loop.epoch_loop.batch_progress.current.processed - @property def total_train_batches(self) -> Union[int, float]: """The total number of training batches, which may change from epoch to epoch. 
diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 157350b63288c..a8f77a3a91dfc 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -485,7 +485,7 @@ def on_predict_batch_start( def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: - self._update(self.train_progress_bar_id, self.train_batch_idx) + self._update(self.train_progress_bar_id, batch_idx + 1) self._update_metrics(trainer, pl_module) self.refresh() @@ -504,9 +504,9 @@ def on_validation_batch_end( if self.is_disabled: return if trainer.sanity_checking: - self._update(self.val_sanity_progress_bar_id, self.val_batch_idx) + self._update(self.val_sanity_progress_bar_id, batch_idx + 1) elif self.val_progress_bar_id is not None: - self._update(self.val_progress_bar_id, self.val_batch_idx) + self._update(self.val_progress_bar_id, batch_idx + 1) self.refresh() def on_test_batch_end( @@ -521,7 +521,7 @@ def on_test_batch_end( if self.is_disabled: return assert self.test_progress_bar_id is not None - self._update(self.test_progress_bar_id, self.test_batch_idx) + self._update(self.test_progress_bar_id, batch_idx + 1) self.refresh() def on_predict_batch_end( @@ -536,7 +536,7 @@ def on_predict_batch_end( if self.is_disabled: return assert self.predict_progress_bar_id is not None - self._update(self.predict_progress_bar_id, self.predict_batch_idx) + self._update(self.predict_progress_bar_id, batch_idx + 1) self.refresh() def _get_train_description(self, current_epoch: int) -> str: diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 121dccd0327bf..fe57b79c4fd02 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -17,6 +17,8 @@ import sys from typing import Any, Dict, Optional, Union +from lightning.pytorch.utilities.types import STEP_OUTPUT + # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed @@ -190,7 +192,6 @@ def init_train_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for training.""" bar = Tqdm( desc=self.train_description, - initial=self.train_batch_idx, position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -204,7 +205,6 @@ def init_predict_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for predicting.""" bar = Tqdm( desc=self.predict_description, - initial=self.train_batch_idx, position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -256,10 +256,12 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") - def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: - current = self.train_batch_idx - if self._should_update(current, self.train_progress_bar.total): - _update_n(self.train_progress_bar, current) + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + n = batch_idx + 1 + if self._should_update(n, self.train_progress_bar.total): + _update_n(self.train_progress_bar, n) 
self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -289,9 +291,18 @@ def on_validation_batch_start( desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") - def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: - if self._should_update(self.val_batch_idx, self.val_progress_bar.total): - _update_n(self.val_progress_bar, self.val_batch_idx) + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + n = batch_idx + 1 + if self._should_update(n, self.val_progress_bar.total): + _update_n(self.val_progress_bar, n) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._train_progress_bar is not None and trainer.state.fn == "fit": @@ -317,9 +328,18 @@ def on_test_batch_start( self.test_progress_bar.initial = 0 self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") - def on_test_batch_end(self, *_: Any) -> None: - if self._should_update(self.test_batch_idx, self.test_progress_bar.total): - _update_n(self.test_progress_bar, self.test_batch_idx) + def on_test_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Optional[STEP_OUTPUT], + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + n = batch_idx + 1 + if self._should_update(n, self.test_progress_bar.total): + _update_n(self.test_progress_bar, n) def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar.close() @@ -343,9 +363,18 @@ def on_predict_batch_start( self.predict_progress_bar.initial = 0 self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") - def on_predict_batch_end(self, *_: Any) -> None: - if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total): - _update_n(self.predict_progress_bar, self.predict_batch_idx) + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + n = batch_idx + 1 + if self._should_update(n, self.predict_progress_bar.total): + _update_n(self.predict_progress_bar, n) def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar.close() diff --git a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py index 24914733e3ec7..3e30da9adaeb1 100644 --- a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py +++ b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py @@ -74,12 +74,6 @@ def dataloaders(self) -> Sequence[DataLoader]: return [] return dataloaders - @property - def prefetch_batches(self) -> int: - batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches - is_unsized = batches[self.current_dataloader_idx] == float("inf") - return int(is_unsized) - @property def done(self) -> bool: """Returns whether all dataloaders are processed or evaluation should be skipped altogether.""" @@ -126,7 +120,7 @@ def reset(self) -> None: def on_run_start(self) -> None: """Runs the 
``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start`` hooks.""" - self._data_fetcher = _select_data_fetcher(self.trainer, prefetch_batches=self.prefetch_batches) + self._data_fetcher = _select_data_fetcher(self.trainer) # hook self._on_evaluation_model_eval() diff --git a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py index 1e0c35ca29761..007068febf2ff 100644 --- a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py @@ -114,11 +114,12 @@ def advance( Raises: StopIteration: If the current batch is None """ - if not isinstance(data_fetcher, _DataLoaderIterDataFetcher): - batch_idx = self.batch_progress.current.ready - batch = next(data_fetcher) - else: - batch_idx, batch = next(data_fetcher) + batch_idx = ( + data_fetcher.fetched + if isinstance(data_fetcher, _DataLoaderIterDataFetcher) + else self.batch_progress.current.ready + ) + batch = next(data_fetcher) self.batch_progress.is_last_batch = data_fetcher.done dataloader_idx = kwargs.get("dataloader_idx", 0) diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index cf7f38707ca46..26d801bf21ad2 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -15,8 +15,6 @@ from collections import OrderedDict from typing import Any, Dict, Optional, Union -import torch - import lightning.pytorch as pl from lightning.pytorch import loops # import as loops to avoid circular imports from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher @@ -188,11 +186,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None: # we are going to train first so the val loop does not need to restart self.val_loop.restarting = False - if not isinstance(data_fetcher, _DataLoaderIterDataFetcher): - batch_idx = self.batch_idx + 1 - batch = next(data_fetcher) - else: - batch_idx, batch = next(data_fetcher) + batch_idx = data_fetcher.fetched if isinstance(data_fetcher, _DataLoaderIterDataFetcher) else self.batch_idx + 1 + batch = next(data_fetcher) self.batch_progress.is_last_batch = data_fetcher.done trainer = self.trainer @@ -284,8 +279,7 @@ def _run_validation(self) -> None: # reload dataloaders self.val_loop._reload_evaluation_dataloaders() - with torch.no_grad(): - self.val_loop.run() + self.val_loop.run() def _accumulated_batches_reached(self) -> bool: """Determine if accumulation will be finished by the end of the current batch.""" diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 8d983b6e91f04..ee29e7b69c0f7 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union from torch.utils.data.dataloader import DataLoader from lightning.fabric.utilities.data import has_len -from lightning.pytorch.trainer.supporters import _shutdown_workers_and_reset_iterator, CombinedLoader +from lightning.pytorch.trainer.supporters import _Sequential, _shutdown_workers_and_reset_iterator, CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -81,7 +81,7 @@ class _PrefetchDataFetcher(_DataFetcher): Args: prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track - whether a batch is the last one (available with :attr:`self.done`) under any training setup. + whether a batch is the last one (available with :attr:`self.done`) when the length is not available. """ def __init__(self, prefetch_batches: int = 1) -> None: @@ -98,6 +98,10 @@ def setup(self, dataloader: Iterable) -> None: def __iter__(self) -> "_PrefetchDataFetcher": super().__iter__() + if self._has_len: + # ignore pre-fetching, it's not necessary + return self + # prefetch batches to know when the iterator will be exhausted in advance iterator = self.dataloader_iter assert iterator is not None for _ in range(self.prefetch_batches): @@ -143,7 +147,7 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: finally: self._stop_profiler() self.fetched += 1 - if not self.prefetch_batches and self._has_len: + if self._has_len: # when we don't prefetch but the dataloader is sized, we use the length for `done` dataloader = self.dataloader assert isinstance(dataloader, Sized) # `_has_len` is True @@ -171,15 +175,24 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None: def __iter__(self) -> "_DataLoaderIterDataFetcher": super().__iter__() - iterator = self.dataloader_iter - assert iterator is not None self.iterator = iter(_DataFetcherWrapper(self)) return self - def __next__(self) -> Tuple[int, Iterator]: - if not self.done: - return self.fetched, self.iterator - raise StopIteration + def __next__(self) -> Union["_DataFetcherWrapper", Tuple["_DataFetcherWrapper", int, int]]: + if self.done: + raise StopIteration + assert isinstance(self.iterator, _DataFetcherWrapper) + if self._is_sequential: + sequential_mode = self.dataloader._iterator + assert isinstance(sequential_mode, _Sequential) + batch_idx = sequential_mode._idx + dataloader_idx = sequential_mode._iterator_idx + return self.iterator, batch_idx, dataloader_idx + return self.iterator + + @property + def _is_sequential(self) -> bool: + return isinstance(self.dataloader, CombinedLoader) and self.dataloader._mode == "sequential" class _DataFetcherWrapper(Iterator): @@ -187,4 +200,10 @@ def __init__(self, data_fetcher: _DataLoaderIterDataFetcher) -> None: self.data_fetcher = data_fetcher def __next__(self) -> Any: - return super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__() + out = super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__() + if self.data_fetcher._is_sequential: + # avoid breaking change with sequential mode and dataloader_iter. 
this is okay because + # dataloader_iter + sequential + multiple dataloaders is not supported so the `*_step(..., batch_idx)` value + # and the batch_index we are excluding here will match + return out[0] + return out diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 73fc7a37175e5..ef9f44c13f465 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -121,11 +121,6 @@ def restarting(self, restarting: bool) -> None: restarting = restarting and epoch_unfinished or self._iteration_based_training() _Loop.restarting.fset(self, restarting) # call the parent setter - @property - def prefetch_batches(self) -> int: - is_unsized = self.trainer.num_training_batches == float("inf") - return int(is_unsized) - @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" @@ -219,7 +214,7 @@ def on_run_start(self) -> None: if self.epoch_loop._should_check_val_epoch(): self.epoch_loop.val_loop._reload_evaluation_dataloaders() - self._data_fetcher = _select_data_fetcher(trainer, self.prefetch_batches) + self._data_fetcher = _select_data_fetcher(trainer) self._is_fresh_start_epoch = True self._results.to(device=trainer.lightning_module.device) diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index d61afe4b98d58..b6f61b75036b6 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -136,7 +136,7 @@ def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None: sampler.set_epoch(epoch) -def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _DataFetcher: +def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher: lightning_module = trainer.lightning_module if trainer.testing: step_fx_name = "test_step" @@ -153,7 +153,7 @@ def _select_data_fetcher(trainer: "pl.Trainer", prefetch_batches: int = 0) -> _D "this signature is experimental and the behavior is subject to change." 
) return _DataLoaderIterDataFetcher() - return _PrefetchDataFetcher(prefetch_batches=prefetch_batches) + return _PrefetchDataFetcher() def _no_grad_context(loop_run: Callable) -> Callable: diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index 3f8f98afb4a13..8830c223c622a 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -135,3 +135,7 @@ def batch_size(self) -> int: @property def sampler(self) -> Union[Sampler, Iterable]: return self._sampler.sampler + + def set_epoch(self, epoch: int) -> None: + if hasattr(self._sampler, "set_epoch"): + self._sampler.set_epoch(epoch) diff --git a/src/lightning/pytorch/strategies/__init__.py b/src/lightning/pytorch/strategies/__init__.py index ed48d873b6160..0cc1dc35b4363 100644 --- a/src/lightning/pytorch/strategies/__init__.py +++ b/src/lightning/pytorch/strategies/__init__.py @@ -23,8 +23,8 @@ from lightning.pytorch.strategies.single_hpu import SingleHPUStrategy # noqa: F401 from lightning.pytorch.strategies.single_tpu import SingleTPUStrategy # noqa: F401 from lightning.pytorch.strategies.strategy import Strategy # noqa: F401 -from lightning.pytorch.strategies.tpu_spawn import TPUSpawnStrategy # noqa: F401 from lightning.pytorch.strategies.utils import _call_register_strategies +from lightning.pytorch.strategies.xla import XLAStrategy # noqa: F401 _STRATEGIES_BASE_MODULE = "lightning.pytorch.strategies" StrategyRegistry = _StrategyRegistry() diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py index 2670c860087eb..692b69f9bfb3c 100644 --- a/src/lightning/pytorch/strategies/launchers/xla.py +++ b/src/lightning/pytorch/strategies/launchers/xla.py @@ -47,7 +47,7 @@ class _XLALauncher(_MultiProcessingLauncher): strategy: A reference to the strategy that is used together with this launcher """ - def __init__(self, strategy: "pl.strategies.TPUSpawnStrategy") -> None: + def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None: if not _XLA_AVAILABLE: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) super().__init__(strategy=strategy, start_method="fork") diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index bb1ce96d62ebe..b3a2feb030343 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -14,13 +14,12 @@ import contextlib import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union import torch from torch import Tensor from torch.nn import Module from torch.optim import Optimizer -from torch.utils.data import DataLoader import lightning.pytorch as pl from lightning.fabric.plugins import CheckpointIO @@ -405,7 +404,7 @@ def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: return output - def process_dataloader(self, dataloader: DataLoader) -> DataLoader: + def process_dataloader(self, dataloader: Iterable) -> Iterable: """Wraps the dataloader if necessary. 
Args: diff --git a/src/lightning/pytorch/strategies/tpu_spawn.py b/src/lightning/pytorch/strategies/xla.py similarity index 82% rename from src/lightning/pytorch/strategies/tpu_spawn.py rename to src/lightning/pytorch/strategies/xla.py index bcb5d669b7186..9bdf53fb032e8 100644 --- a/src/lightning/pytorch/strategies/tpu_spawn.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -13,13 +13,11 @@ # limitations under the License. import io import os -from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union import torch -from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch.nn import Module -from torch.utils.data import DataLoader import lightning.pytorch as pl from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE @@ -34,12 +32,10 @@ from lightning.pytorch.strategies.ddp_spawn import DDPSpawnStrategy from lightning.pytorch.strategies.launchers.xla import _XLALauncher from lightning.pytorch.strategies.strategy import TBroadcast -from lightning.pytorch.trainer.connectors.data_connector import DataConnector from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only -from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import STEP_OUTPUT if TYPE_CHECKING and _XLA_AVAILABLE: from torch_xla.distributed.parallel_loader import MpDeviceLoader @@ -47,11 +43,11 @@ MpDeviceLoader = None -class TPUSpawnStrategy(DDPSpawnStrategy): +class XLAStrategy(DDPSpawnStrategy): """Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` method.""" - strategy_name = "tpu_spawn" + strategy_name = "xla" def __init__( self, @@ -98,34 +94,14 @@ def root_device(self) -> torch.device: return xm.xla_device() @staticmethod - def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None: - def check_has_len(dataloader: DataLoader) -> None: - if not has_len(dataloader): - raise MisconfigurationException( - "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." - " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." - ) - - apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len) - - @staticmethod - def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: - """Validate and fail fast if the dataloaders were passed directly to fit.""" - connector: DataConnector = model.trainer._data_connector - sources = ( - connector._train_dataloader_source, - connector._val_dataloader_source, - connector._test_dataloader_source, - connector._predict_dataloader_source, - ) - for source in sources: - if not source.is_module(): - assert source.instance is not None - assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule)) - TPUSpawnStrategy._validate_dataloader(source.instance) + def _validate_dataloader(dataloader: object) -> None: + if not has_len(dataloader): + raise TypeError( + "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." 
+ " HINT: You can mock the length on your dataset to bypass this error." + ) def connect(self, model: "pl.LightningModule") -> None: - TPUSpawnStrategy._validate_patched_dataloaders(model) import torch_xla.distributed.xla_multiprocessing as xmp self.wrapped_model = xmp.MpModelWrapper(_LightningModuleWrapperBase(model)) @@ -166,8 +142,8 @@ def is_distributed(self) -> bool: return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1 - def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader": - TPUSpawnStrategy._validate_dataloader(dataloader) + def process_dataloader(self, dataloader: Iterable) -> "MpDeviceLoader": + XLAStrategy._validate_dataloader(dataloader) from torch_xla.distributed.parallel_loader import MpDeviceLoader if isinstance(dataloader, MpDeviceLoader): @@ -216,7 +192,7 @@ def reduce( invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") if invalid_reduce_op or invalid_reduce_op_str: raise ValueError( - "Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" + "Currently, the XLAStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:" f" {reduce_op}" ) @@ -317,10 +293,7 @@ def teardown(self) -> None: @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: - strategy_registry.register( - "tpu_spawn_debug", cls, description="TPUSpawn Strategy with `debug` as True", debug=True - ) - + strategy_registry.register("xla_debug", cls, description="XLA strategy with `debug` as True", debug=True) strategy_registry.register( cls.strategy_name, cls, diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index f98535050324a..dabf162875597 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -72,7 +72,7 @@ SingleTPUStrategy, Strategy, StrategyRegistry, - TPUSpawnStrategy, + XLAStrategy, ) from lightning.pytorch.strategies.ddp_spawn import _DDP_FORK_ALIASES from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -459,7 +459,7 @@ def _choose_strategy(self) -> Union[Strategy, str]: return SingleHPUStrategy(device=torch.device("hpu")) if self._accelerator_flag == "tpu": if self._parallel_devices and len(self._parallel_devices) > 1: - return TPUSpawnStrategy.strategy_name + return XLAStrategy.strategy_name else: # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device" return SingleTPUStrategy(device=self._parallel_devices[0]) # type: ignore @@ -634,10 +634,10 @@ def _lazy_init_strategy(self) -> None: # TODO: should be moved to _check_strategy_and_fallback(). # Current test check precision first, so keep this check here to meet error order if isinstance(self.accelerator, TPUAccelerator) and not isinstance( - self.strategy, (SingleTPUStrategy, TPUSpawnStrategy) + self.strategy, (SingleTPUStrategy, XLAStrategy) ): raise ValueError( - "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy`," + "The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `XLAStrategy`," f" found {self.strategy.__class__.__name__}." 
) @@ -661,7 +661,7 @@ def is_distributed(self) -> bool: FSDPStrategy, DDPSpawnStrategy, DeepSpeedStrategy, - TPUSpawnStrategy, + XLAStrategy, HPUParallelStrategy, ) is_distributed = isinstance(self.strategy, distributed_strategy) diff --git a/src/lightning/pytorch/trainer/states.py b/src/lightning/pytorch/trainer/states.py index 336a6c08b2778..73b7cb71dcf82 100644 --- a/src/lightning/pytorch/trainer/states.py +++ b/src/lightning/pytorch/trainer/states.py @@ -63,13 +63,11 @@ class RunningStage(LightningEnum): @property def evaluating(self) -> bool: - return self in (self.VALIDATING, self.TESTING) + return self in (self.VALIDATING, self.TESTING, self.SANITY_CHECKING) @property def dataloader_prefix(self) -> Optional[str]: - if self == self.SANITY_CHECKING: - return None - if self == self.VALIDATING: + if self in (self.VALIDATING, self.SANITY_CHECKING): return "val" return self.value diff --git a/src/lightning/pytorch/trainer/supporters.py b/src/lightning/pytorch/trainer/supporters.py index 2c3872e239dd5..ffa56538adc60 100644 --- a/src/lightning/pytorch/trainer/supporters.py +++ b/src/lightning/pytorch/trainer/supporters.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Iterable -from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar +from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar, Union from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter from typing_extensions import Self, TypedDict @@ -74,27 +74,47 @@ def __next__(self) -> List: return [next(it) for it in self.iterators] -class _Sequential(_ModeIterator[Tuple[int, Any]]): - def __init__(self, iterables: List[Iterable]) -> None: +class _Sequential(_ModeIterator[Tuple[Any, int, int]]): + def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None: super().__init__(iterables) self._iterator_idx = 0 # what would be dataloader_idx self._idx = 0 # what would be batch_idx + self.limits = limits - def __next__(self) -> Tuple[int, Any]: + @property + def limits(self) -> Optional[List[Union[int, float]]]: + """Optional limits per iterator.""" + return self._limits + + @limits.setter + def limits(self, limits: Optional[List[Union[int, float]]]) -> None: + if limits is not None and len(limits) != len(self.iterables): + raise ValueError( + f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(self.iterables)})" + ) + self._limits = limits + + def __next__(self) -> Tuple[Any, int, int]: n = len(self.iterators) - if n == 0: + if n == 0 or self._iterator_idx >= n: raise StopIteration + + # if limits are set, go to the correct iterator + if self.limits is not None: + while self.limits[self._iterator_idx] <= self._idx: + self._use_next_iterator() + if self._iterator_idx >= n: + raise StopIteration + try: out = next(self.iterators[self._iterator_idx]) index = self._idx self._idx += 1 - # the return is enumerated by default - return index, out + # batch, batch_idx, dataloader_idx + return out, index, self._iterator_idx except StopIteration: - self._iterator_idx += 1 - self._idx = 0 - if self._iterator_idx >= n: - raise + # try the next iterator + self._use_next_iterator() return self.__next__() def __iter__(self) -> Self: # type: ignore[valid-type] @@ -108,6 +128,10 @@ def reset(self) -> None: self._iterator_idx = 0 self._idx = 0 + def _use_next_iterator(self) -> None: + 
self._iterator_idx += 1 + self._idx = 0 + class _CombinationMode(TypedDict): fn: Callable[[List[int]], int] @@ -170,28 +194,28 @@ class CombinedLoader(Iterable): >>> combined_loader = CombinedLoader(iterables, 'max_size_cycle') >>> len(combined_loader) 3 - >>> for item in combined_loader: - ... print(item) + >>> for batch in combined_loader: + ... print(batch) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])} >>> combined_loader = CombinedLoader(iterables, 'min_size') >>> len(combined_loader) 2 - >>> for item in combined_loader: - ... print(item) + >>> for batch in combined_loader: + ... print(batch) {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])} {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])} >>> combined_loader = CombinedLoader(iterables, 'sequential') >>> len(combined_loader) 5 - >>> for item in combined_loader: - ... print(*item) - 0 tensor([0, 1, 2, 3]) - 1 tensor([4, 5]) - 0 tensor([0, 1, 2, 3, 4]) - 1 tensor([5, 6, 7, 8, 9]) - 2 tensor([10, 11, 12, 13, 14]) + >>> for batch, batch_idx, dataloader_idx in combined_loader: + ... print(f"{batch} {batch_idx=} {dataloader_idx=}") + tensor([0, 1, 2, 3]) batch_idx=0 dataloader_idx=0 + tensor([4, 5]) batch_idx=1 dataloader_idx=0 + tensor([0, 1, 2, 3, 4]) batch_idx=0 dataloader_idx=1 + tensor([5, 6, 7, 8, 9]) batch_idx=1 dataloader_idx=1 + tensor([10, 11, 12, 13, 14]) batch_idx=2 dataloader_idx=1 """ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None: diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 0f033641fe07e..778acef4c7284 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -957,8 +957,7 @@ def _run_sanity_check(self) -> None: ] # run eval step - with torch.no_grad(): - val_loop.run() + val_loop.run() call._call_callback_hooks(self, "on_sanity_check_end") @@ -1228,7 +1227,7 @@ def lr_scheduler_configs(self) -> List[LRSchedulerConfig]: @property def precision(self) -> _PRECISION_INPUT_STR: - return self.strategy.precision_plugin.precision + return self.strategy.precision_plugin.precision # type: ignore @property def scaler(self) -> Optional[Any]: diff --git a/tests/integrations_app/apps/idle_timeout/app.py b/tests/integrations_app/apps/idle_timeout/app.py index 31e0d7c124ab6..d33df0a616d58 100644 --- a/tests/integrations_app/apps/idle_timeout/app.py +++ b/tests/integrations_app/apps/idle_timeout/app.py @@ -2,7 +2,7 @@ from lightning.app import CloudCompute, LightningApp, LightningFlow, LightningWork from lightning.app.storage.path import _artifacts_path, _filesystem -from lightning.app.utilities.enum import WorkStageStatus, WorkStopReasons +from lightning.app.utilities.enum import WorkStageStatus class SourceFileWriterWork(LightningWork): @@ -35,22 +35,21 @@ def run(self): if self.work.counter == 0: self.work.run() - elif ( - self.work.status.stage == WorkStageStatus.STOPPED - and self.work.status.reason == WorkStopReasons.SIGTERM_SIGNAL_HANDLER - and self.make_check - ): - succeeded_status = self.work.statuses[-3] - stopped_status_pending = self.work.statuses[-2] - stopped_status_sigterm = self.work.statuses[-1] - assert succeeded_status.stage == WorkStageStatus.SUCCEEDED - assert stopped_status_pending.stage == WorkStageStatus.STOPPED - assert stopped_status_pending.reason == WorkStopReasons.PENDING - assert stopped_status_sigterm.stage == 
WorkStageStatus.STOPPED - assert stopped_status_sigterm.reason == WorkStopReasons.SIGTERM_SIGNAL_HANDLER + elif self.work.status.stage == WorkStageStatus.STOPPED and self.make_check: + succeeded_statuses = [status for status in self.work.statuses if status.stage == WorkStageStatus.SUCCEEDED] + # Ensure the work succeeded at some point + assert len(succeeded_statuses) > 0 + succeeded_status = succeeded_statuses[-1] + + stopped_statuses = [status for status in self.work.statuses if status.stage == WorkStageStatus.STOPPED] + + # We want to check that the work started shutting down within the required timeframe, so we take the first + # status that has `stage == STOPPED`. + stopped_status = stopped_statuses[0] + # Note: Account for the controlplane, k8s, SIGTERM handler delays. - assert (stopped_status_pending.timestamp - succeeded_status.timestamp) < 20 - assert (stopped_status_sigterm.timestamp - stopped_status_pending.timestamp) < 120 + assert (stopped_status.timestamp - succeeded_status.timestamp) < 20 + fs = _filesystem() destination_path = _artifacts_path(self.work) / pathlib.Path(*self.work.path.resolve().parts[1:]) assert fs.exists(destination_path) diff --git a/tests/tests_app/utilities/test_network.py b/tests/tests_app/utilities/test_network.py index 1795d5d524966..f8cc25304f0ae 100644 --- a/tests/tests_app/utilities/test_network.py +++ b/tests/tests_app/utilities/test_network.py @@ -2,6 +2,7 @@ import pytest +from lightning.app.core import constants from lightning.app.utilities.network import find_free_network_port, LightningClient @@ -40,6 +41,9 @@ def test_find_free_network_port_cloudspace(_, patch_constants): # Check that all ports are unique assert len(ports) == num_ports + # Shouldn't use the APP_SERVER_PORT + assert constants.APP_SERVER_PORT not in ports + def test_lightning_client_retry_enabled(): diff --git a/tests/tests_fabric/strategies/test_registry.py b/tests/tests_fabric/strategies/test_registry.py index 6c636fdf9795b..07aee5ea91f0f 100644 --- a/tests/tests_fabric/strategies/test_registry.py +++ b/tests/tests_fabric/strategies/test_registry.py @@ -55,7 +55,6 @@ def test_available_strategies_in_registry(): "ddp_fork", "ddp_notebook", "single_tpu", - "tpu_spawn", "xla", "dp", } diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index df7832c12dd1c..051df16528540 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -120,7 +120,7 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script): assert os.environ["LT_NUM_NODES"] == num_nodes -@pytest.mark.parametrize("precision", ["64-true", "32-true", "16-mixed", "bf16-mixed"]) +@pytest.mark.parametrize("precision", ["64-true", "64", "32-true", "32", "16-mixed", "bf16-mixed"]) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_cli_env_vars_precision(precision, monkeypatch, fake_script): monkeypatch.setattr(torch.distributed, "run", Mock()) diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index f13d461312488..8296e2426f4db 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -447,16 +447,15 @@ def test_validate_precision_type(precision): ("16-mixed", "16-mixed", False), ("bf16", "bf16-mixed", True), ("bf16-mixed", "bf16-mixed", False), - (32, "32-true", True), - ("32", "32-true", True), + (32, "32-true", False), + ("32", "32-true", False), ("32-true", "32-true", False), - (64, "64-true", True), - ("64", "64-true", True), + (64, "64-true", False), + ("64",
"64-true", False), ("64-true", "64-true", False), ], ) # mock cuda as available to not be limited by dtype and accelerator compatibility - this is tested elsewhere -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1) @mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False) def test_precision_conversion(patch1, patch2, precision, expected_precision, should_warn): diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 22af966eb7554..a2a4389142da8 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -27,7 +27,7 @@ from lightning.pytorch.accelerators.tpu import TPUAccelerator from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.plugins import PrecisionPlugin, TPUPrecisionPlugin, XLACheckpointIO -from lightning.pytorch.strategies import DDPStrategy, TPUSpawnStrategy +from lightning.pytorch.strategies import DDPStrategy, XLAStrategy from lightning.pytorch.utilities import find_shared_parameters from tests_pytorch.helpers.runif import RunIf from tests_pytorch.trainer.optimization.test_manual_optimization import assert_emtpy_grad @@ -94,7 +94,7 @@ def test_accelerator_tpu(accelerator, devices, tpu_available): trainer = Trainer(accelerator=accelerator, devices=devices) assert isinstance(trainer.accelerator, TPUAccelerator) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + assert isinstance(trainer.strategy, XLAStrategy) assert trainer.num_devices == 8 @@ -177,15 +177,15 @@ def test_strategy_choice_tpu_str_ddp_spawn(tpu_available): @RunIf(skip_windows=True) -def test_strategy_choice_tpu_str_tpu_spawn_debug(tpu_available): - trainer = Trainer(strategy="tpu_spawn_debug", accelerator="tpu", devices=8) - assert isinstance(trainer.strategy, TPUSpawnStrategy) +def test_strategy_choice_tpu_str_xla_debug(tpu_available): + trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8) + assert isinstance(trainer.strategy, XLAStrategy) @RunIf(tpu=True) def test_strategy_choice_tpu_strategy(): - trainer = Trainer(strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + trainer = Trainer(strategy=XLAStrategy(), accelerator="tpu", devices=8) + assert isinstance(trainer.strategy, XLAStrategy) @RunIf(tpu=True) @@ -237,7 +237,7 @@ def forward(self, x): def test_tpu_invalid_raises(tpu_available): - strategy = TPUSpawnStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()) + strategy = XLAStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): Trainer(strategy=strategy, devices=8) @@ -248,14 +248,14 @@ def test_tpu_invalid_raises(tpu_available): def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available): accelerator = TPUAccelerator() - strategy = TPUSpawnStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) + strategy = XLAStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin()) with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): Trainer(strategy=strategy, devices=8) accelerator = TPUAccelerator() strategy = DDPStrategy(accelerator=accelerator, precision_plugin=TPUPrecisionPlugin()) with pytest.raises( - ValueError, match="The 
`TPUAccelerator` can only be used with a `SingleTPUStrategy` or `TPUSpawnStrategy" + ValueError, match="The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `XLAStrategy" ): Trainer(strategy=strategy, devices=8) @@ -267,11 +267,11 @@ def test_xla_checkpoint_plugin_being_default(tpu_available): @RunIf(tpu=True) -@patch("lightning.pytorch.strategies.tpu_spawn.TPUSpawnStrategy.root_device") +@patch("lightning.pytorch.strategies.xla.XLAStrategy.root_device") def test_xla_mp_device_dataloader_attribute(_, monkeypatch): dataset = RandomDataset(32, 64) dataloader = DataLoader(dataset) - strategy = TPUSpawnStrategy() + strategy = XLAStrategy() isinstance_return = True import torch_xla.distributed.parallel_loader as parallel_loader diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 56b7bb8795ff8..d2723e67fa348 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -174,7 +174,7 @@ def mps_count_4(monkeypatch): @pytest.fixture(scope="function") def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(lightning.pytorch.accelerators.tpu, "_XLA_AVAILABLE", True) - monkeypatch.setattr(lightning.pytorch.strategies.tpu_spawn, "_XLA_AVAILABLE", True) + monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", True) monkeypatch.setattr(lightning.pytorch.strategies.single_tpu, "_XLA_AVAILABLE", True) monkeypatch.setattr(lightning.pytorch.plugins.precision.tpu, "_XLA_AVAILABLE", True) monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", True) diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index a5c1462f7f4aa..07c6c1507c8d3 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -64,8 +64,9 @@ def generate(): # we can only know the last batch with sized iterables or when we prefetch is_last_batch = [False, False, prefetch_batches > 0 or dataset_cls is SizedDataset] - fetched = list(range(prefetch_batches + 1, 4)) - fetched += [3] * (3 - len(fetched)) + fetched = ( + [1, 2, 3] if dataset_cls is SizedDataset else [1, 2, 3, 3, 3, 3, 3][prefetch_batches : prefetch_batches + 3] + ) batches = [[1, 1], [2, 2], [3, 3]] if use_combined_loader else [1, 2, 3] expected = list(zip(fetched, batches, is_last_batch)) assert len(expected) == 3 diff --git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index 5d1de82f6536b..1b5e05c502e5a 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -1,3 +1,19 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest import mock +from unittest.mock import call + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel @@ -30,3 +46,20 @@ def predict_step(self, batch, batch_idx): predictions = trainer.predict(model, return_predictions=False) assert predictions is None assert trainer.predict_loop.predictions == [] + + +def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path): + """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction.""" + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + limit_predict_batches=1, + enable_model_summary=False, + enable_checkpointing=False, + logger=False, + ) + trainer.fit_loop.epoch_progress.current.processed = 2 + + with mock.patch("lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper.set_epoch") as set_epoch_mock: + trainer.predict(model) + assert set_epoch_mock.mock_calls == [call(2)] diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py index 031fef8e72ce8..5685739c78837 100644 --- a/tests/tests_pytorch/models/test_tpu.py +++ b/tests/tests_pytorch/models/test_tpu.py @@ -24,7 +24,7 @@ from lightning.pytorch.accelerators import TPUAccelerator from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.strategies import TPUSpawnStrategy +from lightning.pytorch.strategies import XLAStrategy from lightning.pytorch.strategies.launchers.xla import _XLALauncher from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -285,7 +285,7 @@ def wrap_launch_function(fn, strategy, *args, **kwargs): def xla_launch(fn): # TODO: the accelerator should be optional to just launch processes, but this requires lazy initialization accelerator = TPUAccelerator() - strategy = TPUSpawnStrategy(accelerator=accelerator, parallel_devices=list(range(8))) + strategy = XLAStrategy(accelerator=accelerator, parallel_devices=list(range(8))) launcher = _XLALauncher(strategy=strategy) wrapped = partial(wrap_launch_function, fn, strategy) return launcher.launch(wrapped, strategy) @@ -325,7 +325,7 @@ def teardown(self, stage): devices=8, limit_train_batches=0.4, limit_val_batches=0.4, - strategy=TPUSpawnStrategy(debug=True), + strategy=XLAStrategy(debug=True), ) model = DebugModel() @@ -359,6 +359,6 @@ def on_train_start(self): @RunIf(tpu=True) def test_device_type_when_tpu_strategy_passed(tmpdir): - trainer = Trainer(default_root_dir=tmpdir, strategy=TPUSpawnStrategy(), accelerator="tpu", devices=8) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + trainer = Trainer(default_root_dir=tmpdir, strategy=XLAStrategy(), accelerator="tpu", devices=8) + assert isinstance(trainer.strategy, XLAStrategy) assert isinstance(trainer.accelerator, TPUAccelerator) diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py index f5e7384ea1f2c..75b7b63957387 100644 --- a/tests/tests_pytorch/strategies/test_registry.py +++ b/tests/tests_pytorch/strategies/test_registry.py @@ -21,7 +21,7 @@ DeepSpeedStrategy, FSDPStrategy, StrategyRegistry, - TPUSpawnStrategy, + XLAStrategy, ) from tests_pytorch.helpers.runif import RunIf @@ -54,15 +54,15 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy): @RunIf(skip_windows=True) -def test_tpu_spawn_debug_strategy_registry(xla_available): - 
strategy = "tpu_spawn_debug" +def test_xla_debug_strategy_registry(xla_available): + strategy = "xla_debug" assert strategy in StrategyRegistry assert StrategyRegistry[strategy]["init_params"] == {"debug": True} - assert StrategyRegistry[strategy]["strategy"] == TPUSpawnStrategy + assert StrategyRegistry[strategy]["strategy"] == XLAStrategy trainer = Trainer(strategy=strategy) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + assert isinstance(trainer.strategy, XLAStrategy) @RunIf(min_torch="1.12") diff --git a/tests/tests_pytorch/strategies/test_tpu_spawn.py b/tests/tests_pytorch/strategies/test_xla.py similarity index 57% rename from tests/tests_pytorch/strategies/test_tpu_spawn.py rename to tests/tests_pytorch/strategies/test_xla.py index 73f14fa318813..d7724464a5515 100644 --- a/tests/tests_pytorch/strategies/test_tpu_spawn.py +++ b/tests/tests_pytorch/strategies/test_xla.py @@ -21,8 +21,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.strategies import TPUSpawnStrategy -from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.strategies import XLAStrategy from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf @@ -45,39 +44,10 @@ def predict_dataloader(self): _loader_no_len = CustomNotImplementedErrorDataloader(_loader) -@pytest.mark.parametrize( - "train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders", - [ - (_loader_no_len, None, None, None), - (None, _loader_no_len, None, None), - (None, None, _loader_no_len, None), - (None, None, None, _loader_no_len), - (None, [_loader, _loader_no_len], None, None), - ], -) -def test_error_iterable_dataloaders_passed_to_fit( - xla_available, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders -): - """Test that the TPUSpawnStrategy identifies dataloaders with iterable datasets and fails early.""" - trainer = Trainer() - model = BoringModelNoDataloaders() - model.trainer = trainer - - trainer._data_connector.attach_dataloaders( - model, - train_dataloaders=train_dataloaders, - val_dataloaders=val_dataloaders, - test_dataloaders=test_dataloaders, - predict_dataloaders=predict_dataloaders, - ) - - with pytest.raises(MisconfigurationException, match="TPUs do not currently support"): - TPUSpawnStrategy(MagicMock()).connect(model) - - def test_error_process_iterable_dataloader(xla_available): - with pytest.raises(MisconfigurationException, match="TPUs do not currently support"): - TPUSpawnStrategy(MagicMock()).process_dataloader(_loader_no_len) + strategy = XLAStrategy(MagicMock()) + with pytest.raises(TypeError, match="TPUs do not currently support"): + strategy.process_dataloader(_loader_no_len) class BoringModelTPU(BoringModel): @@ -90,9 +60,9 @@ def on_train_start(self) -> None: @RunIf(tpu=True, standalone=True) @mock.patch.dict(os.environ, os.environ.copy(), clear=True) def test_model_tpu_one_core(): - """Tests if device/debug flag is set correctly when training and after teardown for TPUSpawnStrategy.""" + """Tests if device/debug flag is set correctly when training and after teardown for XLAStrategy.""" model = BoringModelTPU() - trainer = Trainer(accelerator="tpu", devices=1, fast_dev_run=True, strategy=TPUSpawnStrategy(debug=True)) - assert isinstance(trainer.strategy, TPUSpawnStrategy) + trainer = Trainer(accelerator="tpu", devices=1, fast_dev_run=True, 
strategy=XLAStrategy(debug=True)) + assert isinstance(trainer.strategy, XLAStrategy) trainer.fit(model) assert "PT_XLA_DEBUG" not in os.environ diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index b97fa36d64bac..009f57d8db3fc 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -601,7 +601,7 @@ def test_unsupported_tpu_choice(tpu_available): ): Trainer(accelerator="tpu", precision="64-true") - # if user didn't set strategy, AcceleratorConnector will choose the TPUSingleStrategy or TPUSpawnStrategy + # if user didn't set strategy, AcceleratorConnector will choose the SingleTPUStrategy or XLAStrategy with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUStrategy`"), pytest.warns( UserWarning, match=r"accelerator='tpu', precision=16-mixed\)` but AMP with fp16 is not supported" ): @@ -839,6 +839,7 @@ def get_defaults(cls): assert connector_default == trainer_defaults[name] +@RunIf(min_cuda_gpus=1) # trigger this test on our GPU pipeline, because we don't install the package on the CPU suite @pytest.mark.skipif(not package_available("lightning_colossalai"), reason="Requires Colossal AI Strategy") def test_colossalai_external_strategy(monkeypatch): with mock.patch( diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py index f3e5c6daf3b3d..740af109dad4a 100644 --- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py +++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py @@ -142,7 +142,7 @@ def test_num_stepping_batches_with_tpu_single(): @RunIf(tpu=True) @mock.patch( - "lightning.pytorch.strategies.tpu_spawn.TPUSpawnStrategy.root_device", + "lightning.pytorch.strategies.xla.XLAStrategy.root_device", new_callable=PropertyMock, return_value=torch.device("xla:0"), ) diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py index 01025975a248f..08af8ca7148e8 100644 --- a/tests/tests_pytorch/trainer/test_supporters.py +++ b/tests/tests_pytorch/trainer/test_supporters.py @@ -122,13 +122,14 @@ def test_combined_loader_modes(): combined_loader = CombinedLoader(iterables, "sequential") assert combined_loader._iterator is None assert len(combined_loader) == sum_len - for total_idx, (idx, item) in enumerate(combined_loader): + for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader): assert isinstance(combined_loader._iterator, _Sequential) - assert isinstance(idx, int) + assert isinstance(batch_idx, int) assert isinstance(item, Tensor) assert idx == lengths[-1] - 1
assert total_idx == sum_len - 1 assert total_idx == len(combined_loader) - 1 + assert dataloader_idx == len(iterables) - 1 def test_combined_loader_raises(): @@ -205,7 +207,6 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader has_break = False for idx, item in enumerate(combined_loader): assert isinstance(item, Sequence) - assert len(item) == 2 if use_multiple_dataloaders else 1 if not use_multiple_dataloaders and idx == 4: has_break = True break @@ -221,6 +222,27 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader assert idx == expected - 1 +@pytest.mark.parametrize( + ("limits", "expected"), + [ + (None, [("a", 0, 0), ("b", 1, 0), ("c", 2, 0), ("d", 0, 1), ("e", 1, 1)]), + ([1, 0], [("a", 0, 0)]), + ([0, float("inf")], [("d", 0, 1), ("e", 1, 1)]), + ([1, 1], [("a", 0, 0), ("d", 0, 1)]), + ], +) +def test_sequential_mode_limits(limits, expected): + iterable1 = ["a", "b", "c"] + iterable2 = ["d", "e"] + iterator = _Sequential([iterable1, iterable2], limits) + assert list(iterator) == expected + + +def test_sequential_mode_limits_raises(): + with pytest.raises(ValueError, match=r"number of limits \(0\) and number of iterables \(2\)"): + _Sequential([0, 1], []) + + @pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]]) def test_combined_loader_sequence_with_map_and_iterable(lengths): class MyIterableDataset(IterableDataset):