From 87bdb0996c365485bbf26c03b0a7e6a4c91c5484 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 12 May 2021 10:31:33 +0100 Subject: [PATCH] Add `available_data_sources` method and a warning when no data source is found (#282) * Add a warning when no data source is found * Add missing docstring * Update Co-authored-by: thomas chaton --- docs/source/general/data.rst | 1 + flash/data/data_module.py | 8 +++++++ flash/data/data_source.py | 4 ++-- flash/data/process.py | 28 +++++++++++++++++++++-- tests/data/test_process.py | 43 ++++++++++++++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 4 deletions(-) diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index a3d46a794a..bbcdc095ee 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -336,6 +336,7 @@ __________ show_val_batch, show_test_batch, show_predict_batch, + available_data_sources, :exclude-members: autogenerate_dataset, diff --git a/flash/data/data_module.py b/flash/data/data_module.py index e5ac91e585..433aeead58 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -327,6 +327,14 @@ def postprocess(self) -> Postprocess: def data_pipeline(self) -> DataPipeline: return DataPipeline(self.data_source, self.preprocess, self.postprocess) + def available_data_sources(self) -> Sequence[str]: + """Get the list of available data source names for use with this :class:`~flash.data.data_module.DataModule`. + + Returns: + The list of data source names. + """ + return self.preprocess.available_data_sources() + @staticmethod def _split_train_val( train_dataset: Dataset, diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 4238dbb514..bcf0a2d738 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -105,7 +105,7 @@ def load_data(self, Args: data: The data required to load the sequence or iterable of samples or sample metadata. dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset - (e.g. ``num_classes``) will also be set on the generated dataset. + (e.g. ``num_classes``) will also be set on the generated dataset. Returns: A sequence or iterable of samples or sample metadata to be used as inputs to @@ -130,7 +130,7 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) sample: An element (sample or sample metadata) from the output of a call to :meth:`~flash.data.data_source.DataSource.load_data`. dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset - (e.g. ``num_classes``) will also be set on the generated dataset. + (e.g. ``num_classes``) will also be set on the generated dataset. Returns: The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the diff --git a/flash/data/process.py b/flash/data/process.py index 11faddd2a8..d6502c0bfb 100644 --- a/flash/data/process.py +++ b/flash/data/process.py @@ -395,13 +395,37 @@ def per_batch_transform_on_device(self, batch: Any) -> Any: """ return self.current_transform(batch) - def data_source_of_name(self, data_source_name: str) -> Optional[DataSource]: + def available_data_sources(self) -> Sequence[str]: + """Get the list of available data source names for use with this :class:`~flash.data.process.Preprocess`. + + Returns: + The list of data source names. + """ + return list(self._data_sources.keys()) + + def data_source_of_name(self, data_source_name: str) -> DataSource: + """Get the :class:`~flash.data.data_source.DataSource` of the given name from the + :class:`~flash.data.process.Preprocess`. + + Args: + data_source_name: The name of the data source to look up. + + Returns: + The :class:`~flash.data.data_source.DataSource` of the given name. + + Raises: + MisconfigurationException: If the requested data source is not configured by this + :class:`~flash.data.process.Preprocess`. + """ if data_source_name == "default": data_source_name = self._default_data_source data_sources = self._data_sources if data_source_name in data_sources: return data_sources[data_source_name] - return None + raise MisconfigurationException( + f"No '{data_source_name}' data source is available for use with the {type(self)}. The available data " + f"sources are: {', '.join(self.available_data_sources())}." + ) class DefaultPreprocess(Preprocess): diff --git a/tests/data/test_process.py b/tests/data/test_process.py index 66df027b5b..8cf5de3dc2 100644 --- a/tests/data/test_process.py +++ b/tests/data/test_process.py @@ -16,11 +16,14 @@ import pytest import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import DataLoader from flash import Task, Trainer from flash.core.classification import Labels, LabelsState +from flash.data.data_module import DataModule from flash.data.data_pipeline import DataPipeline, DataPipelineState, DefaultPreprocess +from flash.data.data_source import DefaultDataSources from flash.data.process import Serializer, SerializerMapping from flash.data.properties import ProcessState, Properties @@ -131,3 +134,43 @@ def __init__(self): model = CustomModel.load_from_checkpoint(checkpoint_file) assert isinstance(model._data_pipeline_state, DataPipelineState) assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"]) + + +class CustomPreprocess(DefaultPreprocess): + + def __init__(self): + super().__init__( + data_sources={ + "test": Mock(return_value="test"), + DefaultDataSources.TENSOR: Mock(return_value="tensor"), + }, + default_data_source="test", + ) + + +def test_data_source_of_name(): + + preprocess = CustomPreprocess() + + assert preprocess.data_source_of_name("test")() == "test" + assert preprocess.data_source_of_name(DefaultDataSources.TENSOR)() == "tensor" + assert preprocess.data_source_of_name("tensor")() == "tensor" + assert preprocess.data_source_of_name("default")() == "test" + + with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"): + preprocess.data_source_of_name("not available") + + +def test_available_data_sources(): + + preprocess = CustomPreprocess() + + assert DefaultDataSources.TENSOR in preprocess.available_data_sources() + assert "test" in preprocess.available_data_sources() + assert len(preprocess.available_data_sources()) == 2 + + data_module = DataModule(preprocess=preprocess) + + assert DefaultDataSources.TENSOR in data_module.available_data_sources() + assert "test" in data_module.available_data_sources() + assert len(data_module.available_data_sources()) == 2