diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py
index ffaec51891412..a04fc0c1055f8 100644
--- a/src/lightning/fabric/strategies/xla.py
+++ b/src/lightning/fabric/strategies/xla.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import io
 import os
-from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch import Tensor
@@ -30,7 +30,6 @@
 from lightning.fabric.strategies import ParallelStrategy
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.strategies.strategy import TBroadcast
-from lightning.fabric.utilities.apply_func import apply_to_collection
 from lightning.fabric.utilities.data import has_len
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from lightning.fabric.utilities.types import _PATH, ReduceOp
@@ -222,12 +221,9 @@ def _set_world_ranks(self) -> None:
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
     @staticmethod
-    def _validate_dataloader(dataloaders: DataLoader) -> None:
-        def check_has_len(dataloader: DataLoader) -> None:
-            if not has_len(dataloader):
-                raise TypeError(
-                    "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
-                    " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
-                )
-
-        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)
+    def _validate_dataloader(dataloader: object) -> None:
+        if not has_len(dataloader):
+            raise TypeError(
+                "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
+                " HINT: You can mock the length on your dataset to bypass this error."
+            )
diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py
index bb1ce96d62ebe..b3a2feb030343 100644
--- a/src/lightning/pytorch/strategies/strategy.py
+++ b/src/lightning/pytorch/strategies/strategy.py
@@ -14,13 +14,12 @@
 import contextlib
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins import CheckpointIO
@@ -405,7 +404,7 @@ def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
     def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
         return output
 
-    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
+    def process_dataloader(self, dataloader: Iterable) -> Iterable:
         """Wraps the dataloader if necessary.
 
         Args:
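The change above loosens the base `Strategy.process_dataloader` contract from `DataLoader -> DataLoader` to `Iterable -> Iterable`, so a strategy may accept and return any iterable of batches. A minimal sketch of what the loosened signature permits; `_DevicePrefetcher` is invented for illustration and is not a Lightning class:

```python
# Hypothetical sketch of what the loosened `Iterable` contract permits.
# `_DevicePrefetcher` is invented for illustration; it is not a Lightning API.
from typing import Iterable, Iterator


class _DevicePrefetcher:
    """Wraps any iterable of batches; no DataLoader-specific attributes needed."""

    def __init__(self, loader: Iterable) -> None:
        self._loader = loader

    def __iter__(self) -> Iterator:
        for batch in self._loader:
            # A real wrapper (e.g. XLA's MpDeviceLoader) would move the batch
            # to the target device here before yielding it.
            yield batch


def process_dataloader(dataloader: Iterable) -> Iterable:
    # Mirrors the new base-class signature: Iterable in, Iterable out.
    return _DevicePrefetcher(dataloader)
```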
import io import os -from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union import torch -from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch.nn import Module -from torch.utils.data import DataLoader import lightning.pytorch as pl from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE @@ -34,12 +32,10 @@ from lightning.pytorch.strategies.ddp_spawn import DDPSpawnStrategy from lightning.pytorch.strategies.launchers.xla import _XLALauncher from lightning.pytorch.strategies.strategy import TBroadcast -from lightning.pytorch.trainer.connectors.data_connector import DataConnector from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only -from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import STEP_OUTPUT if TYPE_CHECKING and _XLA_AVAILABLE: from torch_xla.distributed.parallel_loader import MpDeviceLoader @@ -98,34 +94,14 @@ def root_device(self) -> torch.device: return xm.xla_device() @staticmethod - def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None: - def check_has_len(dataloader: DataLoader) -> None: - if not has_len(dataloader): - raise MisconfigurationException( - "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." - " HINT: You can mock the length on your dataset to bypass this MisconfigurationException." - ) - - apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len) - - @staticmethod - def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: - """Validate and fail fast if the dataloaders were passed directly to fit.""" - connector: DataConnector = model.trainer._data_connector - sources = ( - connector._train_dataloader_source, - connector._val_dataloader_source, - connector._test_dataloader_source, - connector._predict_dataloader_source, - ) - for source in sources: - if not source.is_module(): - assert source.instance is not None - assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule)) - TPUSpawnStrategy._validate_dataloader(source.instance) + def _validate_dataloader(dataloader: object) -> None: + if not has_len(dataloader): + raise TypeError( + "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`." + " HINT: You can mock the length on your dataset to bypass this error." 
diff --git a/tests/tests_pytorch/strategies/test_tpu_spawn.py b/tests/tests_pytorch/strategies/test_tpu_spawn.py
index 73f14fa318813..ddc3c2c76b577 100644
--- a/tests/tests_pytorch/strategies/test_tpu_spawn.py
+++ b/tests/tests_pytorch/strategies/test_tpu_spawn.py
@@ -22,7 +22,6 @@
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.strategies import TPUSpawnStrategy
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader
 from tests_pytorch.helpers.runif import RunIf
@@ -45,39 +44,10 @@ def predict_dataloader(self):
 _loader_no_len = CustomNotImplementedErrorDataloader(_loader)
 
 
-@pytest.mark.parametrize(
-    "train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders",
-    [
-        (_loader_no_len, None, None, None),
-        (None, _loader_no_len, None, None),
-        (None, None, _loader_no_len, None),
-        (None, None, None, _loader_no_len),
-        (None, [_loader, _loader_no_len], None, None),
-    ],
-)
-def test_error_iterable_dataloaders_passed_to_fit(
-    xla_available, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders
-):
-    """Test that the TPUSpawnStrategy identifies dataloaders with iterable datasets and fails early."""
-    trainer = Trainer()
-    model = BoringModelNoDataloaders()
-    model.trainer = trainer
-
-    trainer._data_connector.attach_dataloaders(
-        model,
-        train_dataloaders=train_dataloaders,
-        val_dataloaders=val_dataloaders,
-        test_dataloaders=test_dataloaders,
-        predict_dataloaders=predict_dataloaders,
-    )
-
-    with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
-        TPUSpawnStrategy(MagicMock()).connect(model)
-
-
 def test_error_process_iterable_dataloader(xla_available):
-    with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
-        TPUSpawnStrategy(MagicMock()).process_dataloader(_loader_no_len)
+    strategy = TPUSpawnStrategy(MagicMock())
+    with pytest.raises(TypeError, match="TPUs do not currently support"):
+        strategy.process_dataloader(_loader_no_len)
 
 
 class BoringModelTPU(BoringModel):
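The HINT in the error message refers to giving an iterable-style dataset a `__len__` so the length check passes. A hypothetical example of that workaround; `SizedStream` is an invented name, not part of the codebase:

```python
# Hypothetical workaround following the error's HINT: mock a length on the
# iterable-style dataset so `has_len(dataloader)` succeeds.
from torch.utils.data import DataLoader, IterableDataset


class SizedStream(IterableDataset):
    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        yield from range(self.n)

    def __len__(self) -> int:
        # The mocked length should match what __iter__ actually yields,
        # otherwise PyTorch warns at iteration time.
        return self.n


loader = DataLoader(SizedStream(8))
assert len(loader) == 8  # default batch_size=1: one batch per item
```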