Run XLA's dataloader validation per dataloader (#16775)
carmocca authored Feb 16, 2023
1 parent 57f2f1c commit c9452df
Showing 4 changed files with 21 additions and 80 deletions.
18 changes: 7 additions & 11 deletions src/lightning/fabric/strategies/xla.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import io
 import os
-from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch import Tensor
@@ -30,7 +30,6 @@
 from lightning.fabric.strategies import ParallelStrategy
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.strategies.strategy import TBroadcast
-from lightning.fabric.utilities.apply_func import apply_to_collection
 from lightning.fabric.utilities.data import has_len
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from lightning.fabric.utilities.types import _PATH, ReduceOp
@@ -222,12 +221,9 @@ def _set_world_ranks(self) -> None:
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
     @staticmethod
-    def _validate_dataloader(dataloaders: DataLoader) -> None:
-        def check_has_len(dataloader: DataLoader) -> None:
-            if not has_len(dataloader):
-                raise TypeError(
-                    "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
-                    " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
-                )
-
-        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)
+    def _validate_dataloader(dataloader: object) -> None:
+        if not has_len(dataloader):
+            raise TypeError(
+                "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
+                " HINT: You can mock the length on your dataset to bypass this error."
+            )
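
A minimal standalone sketch of the check this hunk introduces (the `has_len` stub below only approximates `lightning.fabric.utilities.data.has_len`, and `_Stream` is an illustrative dataset):

    from torch.utils.data import DataLoader, IterableDataset

    class _Stream(IterableDataset):
        def __iter__(self):
            yield from range(3)

    def has_len(dataloader: object) -> bool:
        # Approximation: len() on a DataLoader over an IterableDataset without
        # __len__ raises TypeError, so the loader counts as length-less.
        try:
            len(dataloader)
            return True
        except (TypeError, NotImplementedError):
            return False

    def _validate_dataloader(dataloader: object) -> None:
        if not has_len(dataloader):
            raise TypeError("TPUs do not currently support IterableDataset objects ...")

    _validate_dataloader(DataLoader(_Stream()))  # raises TypeError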
5 changes: 2 additions & 3 deletions src/lightning/pytorch/strategies/strategy.py
@@ -14,13 +14,12 @@
 import contextlib
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins import CheckpointIO
@@ -405,7 +404,7 @@ def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
     def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
         return output
 
-    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
+    def process_dataloader(self, dataloader: Iterable) -> Iterable:
         """Wraps the dataloader if necessary.
 
         Args:
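
Widening the hook from `DataLoader` to `Iterable` matches how the base class treats the argument: it is returned as-is, with no `DataLoader`-specific attributes assumed. A hedged standalone sketch of that contract (not the class itself):

    from typing import Iterable

    def process_dataloader(dataloader: Iterable) -> Iterable:
        # Base-class behavior: pass-through; subclasses may wrap the iterable.
        return dataloader

    # Any iterable is now a valid argument, not only torch.utils.data.DataLoader:
    process_dataloader([{"x": 1}, {"x": 2}])
    process_dataloader(i for i in range(4))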
42 changes: 9 additions & 33 deletions src/lightning/pytorch/strategies/tpu_spawn.py
@@ -13,13 +13,11 @@
 # limitations under the License.
 import io
 import os
-from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union
+from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union
 
 import torch
-from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.nn import Module
-from torch.utils.data import DataLoader
 
 import lightning.pytorch as pl
 from lightning.fabric.accelerators.tpu import _XLA_AVAILABLE
@@ -34,12 +32,10 @@
 from lightning.pytorch.strategies.ddp_spawn import DDPSpawnStrategy
 from lightning.pytorch.strategies.launchers.xla import _XLALauncher
 from lightning.pytorch.strategies.strategy import TBroadcast
-from lightning.pytorch.trainer.connectors.data_connector import DataConnector
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_zero_only
-from lightning.pytorch.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 if TYPE_CHECKING and _XLA_AVAILABLE:
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
@@ -98,34 +94,14 @@ def root_device(self) -> torch.device:
         return xm.xla_device()
 
     @staticmethod
-    def _validate_dataloader(dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]) -> None:
-        def check_has_len(dataloader: DataLoader) -> None:
-            if not has_len(dataloader):
-                raise MisconfigurationException(
-                    "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
-                    " HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
-                )
-
-        apply_to_collection(dataloaders, dtype=object, wrong_dtype=(Sequence, Mapping), function=check_has_len)
-
-    @staticmethod
-    def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
-        """Validate and fail fast if the dataloaders were passed directly to fit."""
-        connector: DataConnector = model.trainer._data_connector
-        sources = (
-            connector._train_dataloader_source,
-            connector._val_dataloader_source,
-            connector._test_dataloader_source,
-            connector._predict_dataloader_source,
-        )
-        for source in sources:
-            if not source.is_module():
-                assert source.instance is not None
-                assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule))
-                TPUSpawnStrategy._validate_dataloader(source.instance)
+    def _validate_dataloader(dataloader: object) -> None:
+        if not has_len(dataloader):
+            raise TypeError(
+                "TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
+                " HINT: You can mock the length on your dataset to bypass this error."
+            )
 
     def connect(self, model: "pl.LightningModule") -> None:
-        TPUSpawnStrategy._validate_patched_dataloaders(model)
         import torch_xla.distributed.xla_multiprocessing as xmp
 
         self.wrapped_model = xmp.MpModelWrapper(_LightningModuleWrapperBase(model))
@@ -166,7 +142,7 @@ def is_distributed(self) -> bool:
 
         return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1
 
-    def process_dataloader(self, dataloader: DataLoader) -> "MpDeviceLoader":
+    def process_dataloader(self, dataloader: Iterable) -> "MpDeviceLoader":
         TPUSpawnStrategy._validate_dataloader(dataloader)
         from torch_xla.distributed.parallel_loader import MpDeviceLoader
 
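
With the connect-time fail-fast removed, each dataloader is now validated individually at the moment the strategy processes it. A rough sketch of the resulting timing (assumes an XLA/TPU environment for the successful path; `_Sized` and `_Stream` are illustrative datasets):

    from unittest.mock import MagicMock

    from torch.utils.data import DataLoader, Dataset, IterableDataset

    from lightning.pytorch.strategies import TPUSpawnStrategy

    class _Sized(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return idx

    class _Stream(IterableDataset):
        def __iter__(self):
            yield from range(4)

    strategy = TPUSpawnStrategy(MagicMock())
    strategy.process_dataloader(DataLoader(_Sized()))   # validated here, then wrapped
    strategy.process_dataloader(DataLoader(_Stream()))  # validated here: raises TypeError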
36 changes: 3 additions & 33 deletions tests/tests_pytorch/strategies/test_tpu_spawn.py
@@ -22,7 +22,6 @@
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.strategies import TPUSpawnStrategy
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader
 from tests_pytorch.helpers.runif import RunIf
 
@@ -45,39 +44,10 @@ def predict_dataloader(self):
 _loader_no_len = CustomNotImplementedErrorDataloader(_loader)
 
 
-@pytest.mark.parametrize(
-    "train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders",
-    [
-        (_loader_no_len, None, None, None),
-        (None, _loader_no_len, None, None),
-        (None, None, _loader_no_len, None),
-        (None, None, None, _loader_no_len),
-        (None, [_loader, _loader_no_len], None, None),
-    ],
-)
-def test_error_iterable_dataloaders_passed_to_fit(
-    xla_available, train_dataloaders, val_dataloaders, test_dataloaders, predict_dataloaders
-):
-    """Test that the TPUSpawnStrategy identifies dataloaders with iterable datasets and fails early."""
-    trainer = Trainer()
-    model = BoringModelNoDataloaders()
-    model.trainer = trainer
-
-    trainer._data_connector.attach_dataloaders(
-        model,
-        train_dataloaders=train_dataloaders,
-        val_dataloaders=val_dataloaders,
-        test_dataloaders=test_dataloaders,
-        predict_dataloaders=predict_dataloaders,
-    )
-
-    with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
-        TPUSpawnStrategy(MagicMock()).connect(model)
-
-
 def test_error_process_iterable_dataloader(xla_available):
-    with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
-        TPUSpawnStrategy(MagicMock()).process_dataloader(_loader_no_len)
+    strategy = TPUSpawnStrategy(MagicMock())
+    with pytest.raises(TypeError, match="TPUs do not currently support"):
+        strategy.process_dataloader(_loader_no_len)
 
 
 class BoringModelTPU(BoringModel):
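
For reference, `_loader_no_len` relies on a loader whose `__len__` raises. A minimal standalone approximation of `CustomNotImplementedErrorDataloader` (the real helper lives in `tests_pytorch.helpers.dataloaders`; this sketch only mimics its effect):

    from torch.utils.data import DataLoader, Dataset

    class _Sized(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return idx

    class CustomNotImplementedErrorDataloader(DataLoader):
        def __len__(self) -> int:
            raise NotImplementedError  # mimics a stream-style loader without a usable length

    _loader = DataLoader(_Sized())
    _loader_no_len = CustomNotImplementedErrorDataloader(_Sized())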
