Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TPUSpawn + IterableDataset error message #6875

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
import os
import re
import time
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

import torch
import torch.multiprocessing as mp

from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
Expand All @@ -40,13 +42,50 @@
from omegaconf import DictConfig, ListConfig, OmegaConf


if TYPE_CHECKING:
from torch.nn import Module
from torch.utils.data import DataLoader


class TPUSpawnPlugin(DDPSpawnPlugin):

def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
self.tpu_local_core_rank = 0
self.start_method = None

@staticmethod
def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']):
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]

for dataloader in dataloaders:
if not has_len(dataloader):
raise MisconfigurationException(
"TPUs do not currently support IterableDataset objects, the dataset must implement __len__."
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
" HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
)

@staticmethod
def _validate_patched_dataloaders(model: 'Module') -> None:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
"""Validate and fail fast if the dataloaders were passed directly to fit.
"""
if isinstance(model.train_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)

if isinstance(model.val_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)

if isinstance(model.test_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)

if isinstance(model.predict_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)

def connect(self, model: 'Module') -> None:
TPUSpawnPlugin._validate_patched_dataloaders(model)
return super().connect(model)

def setup(self, model: torch.nn.Module) -> torch.nn.Module:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self.create_mp_queue()
return self.model
Expand All @@ -64,7 +103,8 @@ def distributed_sampler_kwargs(self) -> dict:
def is_distributed(self):
return self.world_size != 1

def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
def process_dataloader(self, dataloader: 'DataLoader') -> MpDeviceLoader:
TPUSpawnPlugin._validate_dataloader(dataloader)
device = xm.xla_device()
dataloader = MpDeviceLoader(dataloader, device)
return dataloader
Expand Down
74 changes: 74 additions & 0 deletions tests/plugins/test_tpu_spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock

import pytest
from torch.utils.data import DataLoader

from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader


class BoringModelNoDataloaders(BoringModel):
def train_dataloader(self):
raise NotImplementedError

def val_dataloader(self):
raise NotImplementedError

def test_dataloader(self):
raise NotImplementedError

def predict_dataloader(self):
raise NotImplementedError


_loader = DataLoader(RandomDataset(32, 64))
_loader_no_len = CustomNotImplementedErrorDataloader(_loader)


@pytest.mark.parametrize(
"train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders",
[
(_loader_no_len, None, None, None),
(None, _loader_no_len, None, None),
(None, None, _loader_no_len, None),
(None, None, None, _loader_no_len),
(None, [_loader, _loader_no_len], None, None),
],
)
def test_error_patched_iterable_dataloaders(
tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
):
model = BoringModelNoDataloaders()
connector = DataConnector(MagicMock())

connector.attach_dataloaders(
model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders,
test_dataloaders=test_dataloaders,
predict_dataloaders=predict_dataloaders,
)

with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
TPUSpawnPlugin(MagicMock()).connect(model)


def test_error_process_iterable_dataloader(tmpdir):
with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len)