Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add available_data_sources method and a warning when no data source is found #282

Merged
merged 5 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ __________
show_val_batch,
show_test_batch,
show_predict_batch,
available_data_sources,
:exclude-members:
autogenerate_dataset,

Expand Down
8 changes: 8 additions & 0 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,14 @@ def postprocess(self) -> Postprocess:
def data_pipeline(self) -> DataPipeline:
    """Assemble a :class:`DataPipeline` from this object's data source, preprocess and postprocess.

    Returns:
        A new ``DataPipeline`` built from ``self.data_source``, ``self.preprocess`` and ``self.postprocess``.
    """
    source, pre, post = self.data_source, self.preprocess, self.postprocess
    return DataPipeline(source, pre, post)

def available_data_sources(self) -> Sequence[str]:
    """Return the data source names usable with this :class:`~flash.data.data_module.DataModule`.

    The lookup is delegated to ``self.preprocess``.

    Returns:
        The data source names reported by ``self.preprocess.available_data_sources()``.
    """
    preprocess = self.preprocess
    return preprocess.available_data_sources()

@staticmethod
def _split_train_val(
train_dataset: Dataset,
Expand Down
4 changes: 2 additions & 2 deletions flash/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def load_data(self,
Args:
data: The data required to load the sequence or iterable of samples or sample metadata.
dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset
(e.g. ``num_classes``) will also be set on the generated dataset.
(e.g. ``num_classes``) will also be set on the generated dataset.

Returns:
A sequence or iterable of samples or sample metadata to be used as inputs to
Expand All @@ -130,7 +130,7 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None)
sample: An element (sample or sample metadata) from the output of a call to
:meth:`~flash.data.data_source.DataSource.load_data`.
dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset
(e.g. ``num_classes``) will also be set on the generated dataset.
(e.g. ``num_classes``) will also be set on the generated dataset.

Returns:
The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the
Expand Down
28 changes: 26 additions & 2 deletions flash/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,37 @@ def per_batch_transform_on_device(self, batch: Any) -> Any:
"""
return self.current_transform(batch)

def data_source_of_name(self, data_source_name: str) -> Optional[DataSource]:
def available_data_sources(self) -> Sequence[str]:
    """Return the data source names usable with this :class:`~flash.data.process.Preprocess`.

    Returns:
        A new list holding the keys of ``self._data_sources``.
    """
    registered_names = self._data_sources.keys()
    return list(registered_names)

def data_source_of_name(self, data_source_name: str) -> DataSource:
    """Get the :class:`~flash.data.data_source.DataSource` of the given name from the
    :class:`~flash.data.process.Preprocess`.

    Args:
        data_source_name: The name of the data source to look up. The special name ``"default"`` is replaced by
            ``self._default_data_source`` before the lookup.

    Returns:
        The :class:`~flash.data.data_source.DataSource` of the given name.

    Raises:
        MisconfigurationException: If the requested data source is not configured by this
            :class:`~flash.data.process.Preprocess`.
    """
    if data_source_name == "default":
        data_source_name = self._default_data_source
    data_sources = self._data_sources
    if data_source_name in data_sources:
        return data_sources[data_source_name]
    # An unknown name is reported as an error that lists the valid names; it must not fall back to returning
    # ``None``, which would hide the misconfiguration until the missing data source is first used.
    raise MisconfigurationException(
        f"No '{data_source_name}' data source is available for use with the {type(self)}. The available data "
        f"sources are: {', '.join(self.available_data_sources())}."
    )


class DefaultPreprocess(Preprocess):
Expand Down
43 changes: 43 additions & 0 deletions tests/data/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

import pytest
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader

from flash import Task, Trainer
from flash.core.classification import Labels, LabelsState
from flash.data.data_module import DataModule
from flash.data.data_pipeline import DataPipeline, DataPipelineState, DefaultPreprocess
from flash.data.data_source import DefaultDataSources
from flash.data.process import Serializer, SerializerMapping
from flash.data.properties import ProcessState, Properties

Expand Down Expand Up @@ -131,3 +134,43 @@ def __init__(self):
model = CustomModel.load_from_checkpoint(checkpoint_file)
assert isinstance(model._data_pipeline_state, DataPipelineState)
assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"])


class CustomPreprocess(DefaultPreprocess):
    """Preprocess fixture with two mocked data sources, using ``"test"`` as the default data source."""

    def __init__(self):
        # Each mock returns its own name when called, so tests can tell which data source was resolved.
        mocked_sources = {
            "test": Mock(return_value="test"),
            DefaultDataSources.TENSOR: Mock(return_value="tensor"),
        }
        super().__init__(data_sources=mocked_sources, default_data_source="test")


def test_data_source_of_name():
    """Check name lookup, the ``"default"`` alias, and the error raised for an unknown data source name."""
    preprocess = CustomPreprocess()

    # (name passed to the lookup, value returned by the mocked data source that should be resolved)
    expected_by_name = (
        ("test", "test"),
        (DefaultDataSources.TENSOR, "tensor"),
        ("tensor", "tensor"),
        ("default", "test"),
    )
    for name, expected in expected_by_name:
        assert preprocess.data_source_of_name(name)() == expected

    with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"):
        preprocess.data_source_of_name("not available")


def test_available_data_sources():
    """Check that the preprocess and a data module wrapping it both report the same two data source names."""

    def _assert_reports_both_sources(owner):
        assert DefaultDataSources.TENSOR in owner.available_data_sources()
        assert "test" in owner.available_data_sources()
        assert len(owner.available_data_sources()) == 2

    preprocess = CustomPreprocess()
    _assert_reports_both_sources(preprocess)

    # The data module is built only after the preprocess checks pass.
    _assert_reports_both_sources(DataModule(preprocess=preprocess))