Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Add available_data_sources method and a warning when no data source is found #282

Merged
merged 5 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ __________
show_val_batch,
show_test_batch,
show_predict_batch,
available_data_sources,
:exclude-members:
autogenerate_dataset,

Expand Down
8 changes: 8 additions & 0 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,14 @@ def postprocess(self) -> Postprocess:
def data_pipeline(self) -> DataPipeline:
return DataPipeline(self.data_source, self.preprocess, self.postprocess)

def available_data_sources(self) -> Sequence[str]:
"""Get the list of available data source names for use with this :class:`~flash.data.data_module.DataModule`.

Returns:
The list of data source names.
"""
return self.preprocess.available_data_sources()

@staticmethod
def _split_train_val(
train_dataset: Dataset,
Expand Down
4 changes: 2 additions & 2 deletions flash/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def load_data(self,
Args:
data: The data required to load the sequence or iterable of samples or sample metadata.
dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset
(e.g. ``num_classes``) will also be set on the generated dataset.
(e.g. ``num_classes``) will also be set on the generated dataset.

Returns:
A sequence or iterable of samples or sample metadata to be used as inputs to
Expand All @@ -130,7 +130,7 @@ def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None)
sample: An element (sample or sample metadata) from the output of a call to
:meth:`~flash.data.data_source.DataSource.load_data`.
dataset: Overriding methods can optionally include the dataset argument. Any attributes set on the dataset
(e.g. ``num_classes``) will also be set on the generated dataset.
(e.g. ``num_classes``) will also be set on the generated dataset.

Returns:
The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the
Expand Down
28 changes: 26 additions & 2 deletions flash/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,37 @@ def per_batch_transform_on_device(self, batch: Any) -> Any:
"""
return self.current_transform(batch)

def data_source_of_name(self, data_source_name: str) -> Optional[DataSource]:
def available_data_sources(self) -> Sequence[str]:
"""Get the list of available data source names for use with this :class:`~flash.data.process.Preprocess`.

Returns:
The list of data source names.
"""
return list(self._data_sources.keys())

def data_source_of_name(self, data_source_name: str) -> DataSource:
"""Get the :class:`~flash.data.data_source.DataSource` of the given name from the
:class:`~flash.data.process.Preprocess`.

Args:
data_source_name: The name of the data source to look up.

Returns:
The :class:`~flash.data.data_source.DataSource` of the given name.

Raises:
MisconfigurationException: If the requested data source is not configured by this
:class:`~flash.data.process.Preprocess`.
"""
if data_source_name == "default":
data_source_name = self._default_data_source
data_sources = self._data_sources
if data_source_name in data_sources:
return data_sources[data_source_name]
return None
raise MisconfigurationException(
f"No '{data_source_name}' data source is available for use with the {type(self)}. The available data "
f"sources are: {', '.join(self.available_data_sources())}."
)


class DefaultPreprocess(Preprocess):
Expand Down
43 changes: 43 additions & 0 deletions tests/data/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

import pytest
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader

from flash import Task, Trainer
from flash.core.classification import Labels, LabelsState
from flash.data.data_module import DataModule
from flash.data.data_pipeline import DataPipeline, DataPipelineState, DefaultPreprocess
from flash.data.data_source import DefaultDataSources
from flash.data.process import Serializer, SerializerMapping
from flash.data.properties import ProcessState, Properties

Expand Down Expand Up @@ -131,3 +134,43 @@ def __init__(self):
model = CustomModel.load_from_checkpoint(checkpoint_file)
assert isinstance(model._data_pipeline_state, DataPipelineState)
assert model._data_pipeline_state._state[LabelsState] == LabelsState(["a", "b"])


class CustomPreprocess(DefaultPreprocess):

def __init__(self):
super().__init__(
data_sources={
"test": Mock(return_value="test"),
DefaultDataSources.TENSOR: Mock(return_value="tensor"),
},
default_data_source="test",
)


def test_data_source_of_name():

preprocess = CustomPreprocess()

assert preprocess.data_source_of_name("test")() == "test"
assert preprocess.data_source_of_name(DefaultDataSources.TENSOR)() == "tensor"
assert preprocess.data_source_of_name("tensor")() == "tensor"
assert preprocess.data_source_of_name("default")() == "test"

with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"):
preprocess.data_source_of_name("not available")


def test_available_data_sources():

preprocess = CustomPreprocess()

assert DefaultDataSources.TENSOR in preprocess.available_data_sources()
assert "test" in preprocess.available_data_sources()
assert len(preprocess.available_data_sources()) == 2

data_module = DataModule(preprocess=preprocess)

assert DefaultDataSources.TENSOR in data_module.available_data_sources()
assert "test" in data_module.available_data_sources()
assert len(data_module.available_data_sources()) == 2