From d0adc6127625797091eda2d5322b32e5a747d257 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 5 Nov 2021 12:54:59 +0000 Subject: [PATCH] Support PL 1.5.0 (#933) Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- CHANGELOG.md | 6 +++ flash/core/data/data_pipeline.py | 50 +++++++++++++++---- flash/core/data/new_data_module.py | 8 +-- flash/core/trainer.py | 17 +++++-- flash/core/utilities/compatibility.py | 20 ++++++++ flash/core/utilities/imports.py | 3 +- flash/image/classification/adapters.py | 12 ++++- .../classification/integrations/baal/loop.py | 44 ++++++++++++---- flash/image/embedding/vissl/hooks.py | 3 +- flash/video/classification/model.py | 5 +- requirements.txt | 2 +- tests/audio/classification/test_data.py | 7 --- .../test_data_model_integration.py | 2 +- tests/core/data/test_data_pipeline.py | 2 +- tests/core/utilities/test_lightning_cli.py | 18 +------ tests/image/classification/test_data.py | 5 -- tests/image/detection/test_data.py | 6 --- 17 files changed, 138 insertions(+), 72 deletions(-) create mode 100644 flash/core/utilities/compatibility.py diff --git a/CHANGELOG.md b/CHANGELOG.md index feb390c8c9..db2fcb4df2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,8 +18,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where validation metrics could be aggregated together with test metrics in some cases ([#900](https://github.com/PyTorchLightning/lightning-flash/pull/900)) + + - Fixed a bug where the latest versions of torchmetrics and Lightning Flash could not be installed together ([#902](https://github.com/PyTorchLightning/lightning-flash/pull/902)) + +- Fixed compatibility with PyTorch-Lightning 1.5 ([#933](https://github.com/PyTorchLightning/lightning-flash/pull/933)) + + ## [0.5.1] - 2021-10-26 ### Added diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index fd5ee8ef33..15f0afd035 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -18,24 +18,40 @@ from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, TYPE_CHECKING, Union import torch -from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from torch.utils.data import DataLoader, IterableDataset +import flash from flash.core.data.auto_dataset import IterableAutoDataset from flash.core.data.batch import _DeserializeProcessor, _Postprocessor, _Preprocessor, _Sequential, _SerializeProcessor from flash.core.data.data_source import DataSource from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess, Serializer from flash.core.data.properties import ProcessState from flash.core.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX -from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3 +from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_4_3, _PL_GREATER_EQUAL_1_5_0 from flash.core.utilities.stages import _RUNNING_STAGE_MAPPING, RunningStage +if not _PL_GREATER_EQUAL_1_5_0: + from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader + if TYPE_CHECKING: from flash.core.model import Task +class DataLoaderGetter: + """A utility class to be used when patching the ``{stage}_dataloader`` attribute of a LightningModule.""" + + def __init__(self, dataloader): + self.dataloader = dataloader + + # Dummy `__code__` attribute to trick is_overridden + self.__code__ = self.__call__.__code__ + + def __call__(self): + return self.dataloader + + class DataPipelineState: """A class to store and share all process states once a :class:`.DataPipeline` has been initialized.""" @@ -315,16 +331,34 @@ def _get_dataloader(model: "Task", loader_name: str) -> Tuple[DataLoader, str]: dataloader = getattr(model, loader_name) attr_name = loader_name - elif model.trainer and hasattr(model.trainer, "datamodule") and model.trainer.datamodule: - dataloader = getattr(model, f"trainer.datamodule.{loader_name}", None) + elif ( + model.trainer + and hasattr(model.trainer, "datamodule") + and model.trainer.datamodule + and is_overridden(loader_name, model.trainer.datamodule, flash.DataModule) + ): + dataloader = getattr(model.trainer.datamodule, loader_name, None) attr_name = f"trainer.datamodule.{loader_name}" + elif _PL_GREATER_EQUAL_1_5_0 and model.trainer is not None: + source = getattr(model.trainer._data_connector, f"_{loader_name}_source") + if not source.is_module(): + dataloader = source.dataloader() + attr_name = loader_name + + if dataloader is not None: + # Update source as wrapped loader will be attached to model + source.instance = model + source.name = loader_name + return dataloader, attr_name @staticmethod def _patch_dataloader(model: "Task", dataloader: Union[Callable, DataLoader], stage: RunningStage): if isinstance(dataloader, DataLoader): - if _PL_GREATER_EQUAL_1_4_3: + if _PL_GREATER_EQUAL_1_5_0: + dataloader = DataLoaderGetter(dataloader) + elif _PL_GREATER_EQUAL_1_4_3: dataloader = _PatchDataLoader(dataloader, _STAGES_PREFIX[stage]) dataloader.patch(model) else: @@ -369,7 +403,7 @@ def _attach_preprocess_to_model( if not dataloader: continue - if isinstance(dataloader, (_PatchDataLoader, Callable)): + if callable(dataloader): dataloader = dataloader() if dataloader is None: @@ -504,9 +538,7 @@ def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[Runnin if not dataloader: continue - if isinstance(dataloader, _PatchDataLoader): - dataloader = dataloader() - elif isinstance(dataloader, Callable): + if callable(dataloader): dataloader = dataloader() if isinstance(dataloader, Sequence): diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py index bd19b2057a..f9aa259df9 100644 --- a/flash/core/data/new_data_module.py +++ b/flash/core/data/new_data_module.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Optional, Tuple, Type, Union import pytorch_lightning as pl import torch @@ -30,14 +30,8 @@ from flash.core.data.datasets import BaseDataset from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _FIFTYONE_AVAILABLE from flash.core.utilities.stages import RunningStage -if _FIFTYONE_AVAILABLE and TYPE_CHECKING: - from fiftyone.core.collections import SampleCollection -else: - SampleCollection = None - class DataModule(DataModule): """A basic DataModule class for all Flash tasks. This class includes references to a diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 08799113c1..6656ed1087 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -25,11 +25,12 @@ from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden from torch.utils.data import DataLoader import flash from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks -from flash.core.utilities.imports import _SERVE_AVAILABLE +from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, _SERVE_AVAILABLE def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): @@ -277,14 +278,24 @@ def request_dataloader( The dataloader """ model, stage, is_legacy = self._parse_request_dataloader_args(args, kwargs) + if is_legacy: self.call_hook(f"on_{stage}_dataloader") dataloader = getattr(model, f"{stage}_dataloader")() else: hook = f"{stage.dataloader_prefix}_dataloader" self.call_hook("on_" + hook, pl_module=model) - dataloader = self.call_hook(hook, pl_module=model) + + if is_overridden(hook, model): + dataloader = self.call_hook(hook, pl_module=model) + elif _PL_GREATER_EQUAL_1_5_0: + source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source") + dataloader = source.dataloader() + if isinstance(dataloader, tuple): dataloader = list(dataloader) - self.accelerator.barrier("get_dataloaders") + if _PL_GREATER_EQUAL_1_5_0: + self.training_type_plugin.barrier("get_dataloaders") + else: + self.accelerator.barrier("get_dataloaders") return dataloader diff --git a/flash/core/utilities/compatibility.py b/flash/core/utilities/compatibility.py new file mode 100644 index 0000000000..c1656214be --- /dev/null +++ b/flash/core/utilities/compatibility.py @@ -0,0 +1,20 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import Trainer + + +def accelerator_connector(trainer: Trainer): + if hasattr(trainer, "_accelerator_connector"): + return trainer._accelerator_connector + return trainer.accelerator_connector diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 53fdc66e26..0859763532 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -88,7 +88,7 @@ def _compare_version(package: str, op, version) -> bool: _PIL_AVAILABLE = _module_available("PIL") _OPEN3D_AVAILABLE = _module_available("open3d") _SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch") -_FASTFACE_AVAILABLE = _module_available("fastface") +_FASTFACE_AVAILABLE = _module_available("fastface") and _compare_version("pytorch_lightning", operator.lt, "1.5.0") _LIBROSA_AVAILABLE = _module_available("librosa") _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse") @@ -118,6 +118,7 @@ class Image: if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") _PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3") + _PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0") _TEXT_AVAILABLE = all( [ diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py index a4b20a283e..ec141ecb54 100644 --- a/flash/image/classification/adapters.py +++ b/flash/image/classification/adapters.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import inspect import os from collections import defaultdict @@ -31,6 +32,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.model import Task from flash.core.registry import FlashRegistry +from flash.core.utilities.compatibility import accelerator_connector from flash.core.utilities.imports import _LEARN2LEARN_AVAILABLE from flash.core.utilities.providers import _LEARN2LEARN from flash.core.utilities.url_error import catch_url_error @@ -183,9 +185,17 @@ def __init__( self.model = self.algorithm_cls(**algorithm_kwargs) + # Patch log to avoid error with learn2learn and PL 1.5 + self.model.log = functools.partial(self._patch_log, self.model.log) + # this algorithm requires a special treatment self._algorithm_has_validated = self.algorithm_cls != l2l.algorithms.LightningPrototypicalNetworks + def _patch_log(self, log, *args, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, **kwargs): + if not on_step and not on_epoch: + on_epoch = True + return log(*args, on_step=on_step, on_epoch=on_epoch, **kwargs) + def _default_transform(self, dataset, ways: int, shots: int, queries) -> List[Callable]: return [ l2l.data.transforms.FusedNWaysKShots(dataset, n=ways, k=shots + queries), @@ -268,7 +278,7 @@ def _convert_dataset( devices = 1 if isinstance(trainer.training_type_plugin, DataParallelPlugin): # when using DP, we need to sample n tasks, so it can splitted across multiple devices. - devices = trainer.accelerator_connector.devices + devices = accelerator_connector(trainer).devices dataset = TaskDataParallel(taskset, epoch_length=epoch_length, devices=devices, collate_fn=None) self.trainer.accumulated_grad_batches = self.meta_batch_size / devices diff --git a/flash/image/classification/integrations/baal/loop.py b/flash/image/classification/integrations/baal/loop.py index f71a4d41c0..ea94a9fb10 100644 --- a/flash/image/classification/integrations/baal/loop.py +++ b/flash/image/classification/integrations/baal/loop.py @@ -15,19 +15,24 @@ from typing import Any, Dict, Optional import torch +from pytorch_lightning import LightningModule from pytorch_lightning.loops import Loop from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus +from pytorch_lightning.utilities.model_helpers import is_overridden import flash +from flash.core.data.data_pipeline import DataLoaderGetter from flash.core.data.utils import _STAGES_PREFIX -from flash.core.utilities.imports import requires +from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires from flash.core.utilities.stages import RunningStage from flash.image.classification.integrations.baal.data import ActiveLearningDataModule from flash.image.classification.integrations.baal.dropout import InferenceMCDropoutTask +if not _PL_GREATER_EQUAL_1_5_0: + from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader + class ActiveLearningLoop(Loop): @requires("baal") @@ -133,35 +138,52 @@ def __getattr__(self, key): return getattr(self.fit_loop, key) return self.__dict__[key] + def _connect(self, model: LightningModule): + if _PL_GREATER_EQUAL_1_5_0: + self.trainer.training_type_plugin.connect(model) + else: + self.trainer.accelerator.connect(model) + def _reset_fitting(self): self.trainer.state.fn = TrainerFn.FITTING self.trainer.training = True self.trainer.lightning_module.on_train_dataloader() - self.trainer.accelerator.connect(self._lightning_module) + self._connect(self._lightning_module) self.fit_loop.epoch_progress = Progress() def _reset_predicting(self): self.trainer.state.fn = TrainerFn.PREDICTING self.trainer.predicting = True self.trainer.lightning_module.on_predict_dataloader() - self.trainer.accelerator.connect(self.inference_model) + self._connect(self.inference_model) def _reset_testing(self): self.trainer.state.fn = TrainerFn.TESTING self.trainer.state.status = TrainerStatus.RUNNING self.trainer.testing = True self.trainer.lightning_module.on_test_dataloader() - self.trainer.accelerator.connect(self._lightning_module) + self._connect(self._lightning_module) def _reset_dataloader_for_stage(self, running_state: RunningStage): dataloader_name = f"{_STAGES_PREFIX[running_state]}_dataloader" # If the dataloader exists, we reset it. - dataloader = getattr(self.trainer.datamodule, dataloader_name, None) + dataloader = ( + getattr(self.trainer.datamodule, dataloader_name) + if is_overridden(dataloader_name, self.trainer.datamodule) + else None + ) if dataloader: - setattr( - self.trainer.lightning_module, - dataloader_name, - _PatchDataLoader(dataloader(), running_state), - ) + if _PL_GREATER_EQUAL_1_5_0: + setattr( + self.trainer.lightning_module, + dataloader_name, + DataLoaderGetter(dataloader()), + ) + else: + setattr( + self.trainer.lightning_module, + dataloader_name, + _PatchDataLoader(dataloader(), running_state), + ) setattr(self.trainer, dataloader_name, None) getattr(self.trainer, f"reset_{dataloader_name}")(self.trainer.lightning_module) diff --git a/flash/image/embedding/vissl/hooks.py b/flash/image/embedding/vissl/hooks.py index bd9931d886..d9e7369973 100644 --- a/flash/image/embedding/vissl/hooks.py +++ b/flash/image/embedding/vissl/hooks.py @@ -17,6 +17,7 @@ from pytorch_lightning.core.hooks import ModelHooks import flash +from flash.core.utilities.compatibility import accelerator_connector from flash.core.utilities.imports import _VISSL_AVAILABLE if _VISSL_AVAILABLE: @@ -48,7 +49,7 @@ def on_start(self, task: "flash.image.embedding.vissl.adapter.MockVISSLTask") -> # get around vissl distributed training by setting MockTask flags num_nodes = lightning_module.trainer.num_nodes - accelerators_ids = lightning_module.trainer.accelerator_connector.parallel_device_ids + accelerators_ids = accelerator_connector(lightning_module.trainer).parallel_device_ids accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1 task.world_size = num_nodes * accelerator_per_node diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py index bddf95f75c..f70c913f54 100644 --- a/flash/video/classification/model.py +++ b/flash/video/classification/model.py @@ -29,6 +29,7 @@ from flash.core.classification import ClassificationTask, Labels from flash.core.data.data_source import DefaultDataKeys from flash.core.registry import FlashRegistry +from flash.core.utilities.compatibility import accelerator_connector from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE from flash.core.utilities.providers import _PYTORCHVIDEO from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE @@ -146,13 +147,13 @@ def __init__( ) def on_train_start(self) -> None: - if self.trainer.accelerator_connector.is_distributed: + if accelerator_connector(self.trainer).is_distributed: encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset encoded_dataset._video_sampler = DistributedSampler(encoded_dataset._labeled_videos) super().on_train_start() def on_train_epoch_start(self) -> None: - if self.trainer.accelerator_connector.is_distributed: + if accelerator_connector(self.trainer).is_distributed: encoded_dataset = self.trainer.train_dataloader.loaders.dataset.dataset encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch) super().on_train_epoch_start() diff --git a/requirements.txt b/requirements.txt index 4ec374e51c..e29b34f8ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ packaging numpy torch>=1.7.1 torchmetrics>=0.4.0,!=0.5.1 -pytorch-lightning==1.4.9 +pytorch-lightning>=1.4.0 pyDeprecate pandas<1.3.0 jsonargparse[signatures]>=3.17.0 diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index b96034abd8..44010e7af1 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -58,8 +58,6 @@ def test_from_filepaths_smoke(tmpdir): num_workers=0, ) assert spectrograms_data.train_dataloader() is not None - assert spectrograms_data.val_dataloader() is None - assert spectrograms_data.test_dataloader() is None data = next(iter(spectrograms_data.train_dataloader())) imgs, labels = data["input"], data["target"] @@ -130,8 +128,6 @@ def test_from_filepaths_numpy(tmpdir): num_workers=0, ) assert spectrograms_data.train_dataloader() is not None - assert spectrograms_data.val_dataloader() is None - assert spectrograms_data.test_dataloader() is None data = next(iter(spectrograms_data.train_dataloader())) imgs, labels = data["input"], data["target"] @@ -323,9 +319,6 @@ def test_from_folders_only_train(tmpdir): assert imgs.shape == (1, 3, 128, 128) assert labels.shape == (1,) - assert spectrograms_data.val_dataloader() is None - assert spectrograms_data.test_dataloader() is None - @pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_folders_train_val(tmpdir): diff --git a/tests/audio/speech_recognition/test_data_model_integration.py b/tests/audio/speech_recognition/test_data_model_integration.py index eda3ac86b3..34d63fd22a 100644 --- a/tests/audio/speech_recognition/test_data_model_integration.py +++ b/tests/audio/speech_recognition/test_data_model_integration.py @@ -16,9 +16,9 @@ from pathlib import Path import pytest -from pytorch_lightning import Trainer import flash +from flash import Trainer from flash.audio import SpeechRecognition, SpeechRecognitionData from tests.helpers.utils import _AUDIO_TESTING diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index d5853e19a9..68e48b546a 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -18,12 +18,12 @@ import numpy as np import pytest import torch -from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor, tensor from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate +from flash import Trainer from flash.core.data.auto_dataset import IterableAutoDataset from flash.core.data.batch import _Postprocessor, _Preprocessor from flash.core.data.data_module import DataModule diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py index d8fc9982ab..b5a23a0eee 100644 --- a/tests/core/utilities/test_lightning_cli.py +++ b/tests/core/utilities/test_lightning_cli.py @@ -19,6 +19,7 @@ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.plugins.environments import SLURMEnvironment +from flash.core.utilities.compatibility import accelerator_connector from flash.core.utilities.imports import _TORCHVISION_AVAILABLE from flash.core.utilities.lightning_cli import ( instantiate_class, @@ -83,21 +84,6 @@ def test_add_argparse_args_redefined(cli_args): ("--limit_train_batches=100", dict(limit_train_batches=100)), ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)), ("--weights_summary=null", dict(weights_summary=None)), - ( - "", - dict( - # These parameters are marked as Optional[...] in Trainer.__init__, - # with None as default. They should not be changed by the argparse - # interface. - min_steps=None, - max_steps=None, - log_gpu_memory=None, - distributed_backend=None, - weights_save_path=None, - resume_from_checkpoint=None, - profiler=None, - ), - ), ], ) def test_parse_args_parsing(cli_args, expected): @@ -283,7 +269,7 @@ def test_lightning_cli_args_cluster_environments(tmpdir): class TestModel(BoringModel): def on_fit_start(self): # Ensure SLURMEnvironment is set, instead of default LightningEnvironment - assert isinstance(self.trainer.accelerator_connector._cluster_environment, SLURMEnvironment) + assert isinstance(accelerator_connector(self.trainer)._cluster_environment, SLURMEnvironment) self.trainer.ran_asserts = True with mock.patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]): diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index c7773b7377..4e2a3c79ab 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -80,8 +80,6 @@ def test_from_filepaths_smoke(tmpdir): num_workers=0, ) assert img_data.train_dataloader() is not None - assert img_data.val_dataloader() is None - assert img_data.test_dataloader() is None data = next(iter(img_data.train_dataloader())) imgs, labels = data["input"], data["target"] @@ -275,9 +273,6 @@ def test_from_folders_only_train(tmpdir): assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1,) - assert img_data.val_dataloader() is None - assert img_data.test_dataloader() is None - @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_folders_train_val(tmpdir): diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py index 50ce9fb196..3f8d700704 100644 --- a/tests/image/detection/test_data.py +++ b/tests/image/detection/test_data.py @@ -156,9 +156,6 @@ def test_image_detector_data_from_coco(tmpdir): sample = data[0] assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) - assert datamodule.val_dataloader() is None - assert datamodule.test_dataloader() is None - datamodule = ObjectDetectionData.from_coco( train_folder=train_folder, train_ann_file=coco_ann_path, @@ -193,9 +190,6 @@ def test_image_detector_data_from_fiftyone(tmpdir): sample = data[0] assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3) - assert datamodule.val_dataloader() is None - assert datamodule.test_dataloader() is None - datamodule = ObjectDetectionData.from_fiftyone( train_dataset=train_dataset, val_dataset=train_dataset,