From 1380c53e43985b2d264de82db54d024ebf678f29 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 12 May 2021 15:27:19 +0100 Subject: [PATCH] Swap PATHS for FOLDERS and FILES and change TENSOR -> TENSORS (#289) * Swap PATHS for FOLDERS and FILES, and change TENSOR -> TENSORS * Fix a test * Fixes * Fixes --- docs/source/general/data.rst | 7 ++++--- flash/data/data_module.py | 12 ++++++------ flash/data/data_source.py | 5 +++-- flash/video/classification/data.py | 12 +++++++++--- flash/vision/classification/data.py | 7 ++++--- flash/vision/detection/data.py | 5 +++-- flash/vision/embedding/model.py | 2 +- flash/vision/segmentation/data.py | 9 +++++---- tests/data/test_process.py | 10 +++++----- tests/vision/segmentation/test_model.py | 2 +- 10 files changed, 41 insertions(+), 30 deletions(-) diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index bbcdc095ee..0aeba056f2 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -222,11 +222,12 @@ Example:: test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.PATHS: ImagePathsDataSource(), + DefaultDataSources.FILES: ImagePathsDataSource(), + DefaultDataSources.FOLDERS: ImagePathsDataSource(), DefaultDataSources.NUMPY: ImageNumpyDataSource(), - DefaultDataSources.TENSOR: ImageTensorDataSource(), + DefaultDataSources.TENSORS: ImageTensorDataSource(), }, - default_data_source=DefaultDataSources.PATHS, + default_data_source=DefaultDataSources.FILES, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 433aeead58..af325444ec 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -415,7 +415,7 @@ def from_data_source( Examples:: data_module = DataModule.from_data_source( - DefaultDataSources.PATHS, + DefaultDataSources.FOLDERS, train_data="train_folder", train_transform={ "to_tensor_transform": torch.as_tensor, @@ -471,7 +471,7 @@ def from_folders( **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.data.data_module.DataModule` object from the given folders using the - :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.PATHS` + :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.FOLDERS` from the passed or constructed :class:`~flash.data.process.Preprocess`. Args: @@ -511,7 +511,7 @@ def from_folders( ) """ return cls.from_data_source( - DefaultDataSources.PATHS, + DefaultDataSources.FOLDERS, train_folder, val_folder, test_folder, @@ -550,7 +550,7 @@ def from_files( **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.data.data_module.DataModule` object from the given sequences of files using the - :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.PATHS` + :class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.FILES` from the passed or constructed :class:`~flash.data.process.Preprocess`. Args: @@ -594,7 +594,7 @@ def from_files( ) """ return cls.from_data_source( - DefaultDataSources.PATHS, + DefaultDataSources.FILES, (train_files, train_targets), (val_files, val_targets), (test_files, test_targets), @@ -677,7 +677,7 @@ def from_tensors( ) """ return cls.from_data_source( - DefaultDataSources.TENSOR, + DefaultDataSources.TENSORS, (train_data, train_targets), (val_data, val_targets), (test_data, test_targets), diff --git a/flash/data/data_source.py b/flash/data/data_source.py index bcf0a2d738..7cbe294918 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -48,9 +48,10 @@ class DefaultDataSources(LightningEnum): """The ``DefaultDataSources`` enum contains the data source names used by all of the default ``from_*`` methods in :class:`~flash.data.data_module.DataModule`.""" - PATHS = "paths" + FOLDERS = "folders" + FILES = "files" NUMPY = "numpy" - TENSOR = "tensor" + TENSORS = "tensors" CSV = "csv" JSON = "json" diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 5aefd5d14a..22e4181134 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -146,14 +146,20 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.PATHS: VideoClassificationPathsDataSource( + DefaultDataSources.FILES: VideoClassificationPathsDataSource( clip_sampler, video_sampler=video_sampler, decode_audio=decode_audio, decoder=decoder, - ) + ), + DefaultDataSources.FOLDERS: VideoClassificationPathsDataSource( + clip_sampler, + video_sampler=video_sampler, + decode_audio=decode_audio, + decoder=decoder, + ), }, - default_data_source=DefaultDataSources.PATHS, + default_data_source=DefaultDataSources.FILES, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index bee6456a11..8c44bce873 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -52,11 +52,12 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.PATHS: ImagePathsDataSource(), + DefaultDataSources.FILES: ImagePathsDataSource(), + DefaultDataSources.FOLDERS: ImagePathsDataSource(), DefaultDataSources.NUMPY: ImageNumpyDataSource(), - DefaultDataSources.TENSOR: ImageTensorDataSource(), + DefaultDataSources.TENSORS: ImageTensorDataSource(), }, - default_data_source=DefaultDataSources.PATHS, + default_data_source=DefaultDataSources.FILES, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/vision/detection/data.py b/flash/vision/detection/data.py index 5ba6e0eebf..56078c70b5 100644 --- a/flash/vision/detection/data.py +++ b/flash/vision/detection/data.py @@ -104,10 +104,11 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.PATHS: ImagePathsDataSource(), + DefaultDataSources.FILES: ImagePathsDataSource(), + DefaultDataSources.FOLDERS: ImagePathsDataSource(), "coco": COCODataSource(), }, - default_data_source=DefaultDataSources.PATHS, + default_data_source=DefaultDataSources.FILES, ) def get_state_dict(self) -> Dict[str, Any]: diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py index 1392228e37..185fa1f035 100644 --- a/flash/vision/embedding/model.py +++ b/flash/vision/embedding/model.py @@ -23,7 +23,7 @@ from flash.core.registry import FlashRegistry from flash.data.data_source import DefaultDataKeys from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES -from flash.vision.classification.data import ImageClassificationData, ImageClassificationPreprocess +from flash.vision.classification.data import ImageClassificationPreprocess class ImageEmbedder(Task): diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 882f31ef3a..ae84349aad 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -141,11 +141,12 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource(), - DefaultDataSources.TENSOR: TensorDataSource(), + DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(), + DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(), + DefaultDataSources.TENSORS: TensorDataSource(), DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(), }, - default_data_source=DefaultDataSources.PATHS, + default_data_source=DefaultDataSources.FILES, ) def get_state_dict(self) -> Dict[str, Any]: @@ -256,7 +257,7 @@ def from_folders( ) """ return cls.from_data_source( - DefaultDataSources.PATHS, + DefaultDataSources.FOLDERS, (train_folder, train_target_folder), (val_folder, val_target_folder), (test_folder, test_target_folder), diff --git a/tests/data/test_process.py b/tests/data/test_process.py index 8cf5de3dc2..02b4548dcb 100644 --- a/tests/data/test_process.py +++ b/tests/data/test_process.py @@ -142,7 +142,7 @@ def __init__(self): super().__init__( data_sources={ "test": Mock(return_value="test"), - DefaultDataSources.TENSOR: Mock(return_value="tensor"), + DefaultDataSources.TENSORS: Mock(return_value="tensors"), }, default_data_source="test", ) @@ -153,8 +153,8 @@ def test_data_source_of_name(): preprocess = CustomPreprocess() assert preprocess.data_source_of_name("test")() == "test" - assert preprocess.data_source_of_name(DefaultDataSources.TENSOR)() == "tensor" - assert preprocess.data_source_of_name("tensor")() == "tensor" + assert preprocess.data_source_of_name(DefaultDataSources.TENSORS)() == "tensors" + assert preprocess.data_source_of_name("tensors")() == "tensors" assert preprocess.data_source_of_name("default")() == "test" with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"): @@ -165,12 +165,12 @@ def test_available_data_sources(): preprocess = CustomPreprocess() - assert DefaultDataSources.TENSOR in preprocess.available_data_sources() + assert DefaultDataSources.TENSORS in preprocess.available_data_sources() assert "test" in preprocess.available_data_sources() assert len(preprocess.available_data_sources()) == 2 data_module = DataModule(preprocess=preprocess) - assert DefaultDataSources.TENSOR in data_module.available_data_sources() + assert DefaultDataSources.TENSORS in data_module.available_data_sources() assert "test" in data_module.available_data_sources() assert len(data_module.available_data_sources()) == 2 diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py index d3f30129ff..e248271e10 100644 --- a/tests/vision/segmentation/test_model.py +++ b/tests/vision/segmentation/test_model.py @@ -87,7 +87,7 @@ def test_predict_tensor(): img = torch.rand(1, 3, 10, 20) model = SemanticSegmentation(2) data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess()) - out = model.predict(img, data_source="tensor", data_pipeline=data_pipe) + out = model.predict(img, data_source="tensors", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) assert out[0].shape == (196, 196)