Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Swap PATHS for FOLDERS and FILES and change TENSOR -> TENSORS (#289)
Browse files (browse the repository at this point in the history)
* Swap PATHS for FOLDERS and FILES, and change TENSOR -> TENSORS

* Fix a test

* Fixes

* Fixes
  • Loading branch information
ethanwharris authored May 12, 2021
1 parent 87bdb09 commit 1380c53
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 30 deletions.
7 changes: 4 additions & 3 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ Example::
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.PATHS: ImagePathsDataSource(),
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
DefaultDataSources.NUMPY: ImageNumpyDataSource(),
DefaultDataSources.TENSOR: ImageTensorDataSource(),
DefaultDataSources.TENSORS: ImageTensorDataSource(),
},
default_data_source=DefaultDataSources.PATHS,
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
Expand Down
12 changes: 6 additions & 6 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def from_data_source(
Examples::
data_module = DataModule.from_data_source(
DefaultDataSources.PATHS,
DefaultDataSources.FOLDERS,
train_data="train_folder",
train_transform={
"to_tensor_transform": torch.as_tensor,
Expand Down Expand Up @@ -471,7 +471,7 @@ def from_folders(
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.data.data_module.DataModule` object from the given folders using the
:class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.PATHS`
:class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.FOLDERS`
from the passed or constructed :class:`~flash.data.process.Preprocess`.
Args:
Expand Down Expand Up @@ -511,7 +511,7 @@ def from_folders(
)
"""
return cls.from_data_source(
DefaultDataSources.PATHS,
DefaultDataSources.FOLDERS,
train_folder,
val_folder,
test_folder,
Expand Down Expand Up @@ -550,7 +550,7 @@ def from_files(
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.data.data_module.DataModule` object from the given sequences of files using the
:class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.PATHS`
:class:`~flash.data.data_source.DataSource` of name :attr:`~flash.data.data_source.DefaultDataSources.FILES`
from the passed or constructed :class:`~flash.data.process.Preprocess`.
Args:
Expand Down Expand Up @@ -594,7 +594,7 @@ def from_files(
)
"""
return cls.from_data_source(
DefaultDataSources.PATHS,
DefaultDataSources.FILES,
(train_files, train_targets),
(val_files, val_targets),
(test_files, test_targets),
Expand Down Expand Up @@ -677,7 +677,7 @@ def from_tensors(
)
"""
return cls.from_data_source(
DefaultDataSources.TENSOR,
DefaultDataSources.TENSORS,
(train_data, train_targets),
(val_data, val_targets),
(test_data, test_targets),
Expand Down
5 changes: 3 additions & 2 deletions flash/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ class DefaultDataSources(LightningEnum):
"""The ``DefaultDataSources`` enum contains the data source names used by all of the default ``from_*`` methods in
:class:`~flash.data.data_module.DataModule`."""

PATHS = "paths"
FOLDERS = "folders"
FILES = "files"
NUMPY = "numpy"
TENSOR = "tensor"
TENSORS = "tensors"
CSV = "csv"
JSON = "json"

Expand Down
12 changes: 9 additions & 3 deletions flash/video/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,20 @@ def __init__(
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.PATHS: VideoClassificationPathsDataSource(
DefaultDataSources.FILES: VideoClassificationPathsDataSource(
clip_sampler,
video_sampler=video_sampler,
decode_audio=decode_audio,
decoder=decoder,
)
),
DefaultDataSources.FOLDERS: VideoClassificationPathsDataSource(
clip_sampler,
video_sampler=video_sampler,
decode_audio=decode_audio,
decoder=decoder,
),
},
default_data_source=DefaultDataSources.PATHS,
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
Expand Down
7 changes: 4 additions & 3 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ def __init__(
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.PATHS: ImagePathsDataSource(),
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
DefaultDataSources.NUMPY: ImageNumpyDataSource(),
DefaultDataSources.TENSOR: ImageTensorDataSource(),
DefaultDataSources.TENSORS: ImageTensorDataSource(),
},
default_data_source=DefaultDataSources.PATHS,
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
Expand Down
5 changes: 3 additions & 2 deletions flash/vision/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,11 @@ def __init__(
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.PATHS: ImagePathsDataSource(),
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
"coco": COCODataSource(),
},
default_data_source=DefaultDataSources.PATHS,
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion flash/vision/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from flash.core.registry import FlashRegistry
from flash.data.data_source import DefaultDataKeys
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
from flash.vision.classification.data import ImageClassificationData, ImageClassificationPreprocess
from flash.vision.classification.data import ImageClassificationPreprocess


class ImageEmbedder(Task):
Expand Down
9 changes: 5 additions & 4 deletions flash/vision/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ def __init__(
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource(),
DefaultDataSources.TENSOR: TensorDataSource(),
DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(),
DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(),
DefaultDataSources.TENSORS: TensorDataSource(),
DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(),
},
default_data_source=DefaultDataSources.PATHS,
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -256,7 +257,7 @@ def from_folders(
)
"""
return cls.from_data_source(
DefaultDataSources.PATHS,
DefaultDataSources.FOLDERS,
(train_folder, train_target_folder),
(val_folder, val_target_folder),
(test_folder, test_target_folder),
Expand Down
10 changes: 5 additions & 5 deletions tests/data/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(self):
super().__init__(
data_sources={
"test": Mock(return_value="test"),
DefaultDataSources.TENSOR: Mock(return_value="tensor"),
DefaultDataSources.TENSORS: Mock(return_value="tensors"),
},
default_data_source="test",
)
Expand All @@ -153,8 +153,8 @@ def test_data_source_of_name():
preprocess = CustomPreprocess()

assert preprocess.data_source_of_name("test")() == "test"
assert preprocess.data_source_of_name(DefaultDataSources.TENSOR)() == "tensor"
assert preprocess.data_source_of_name("tensor")() == "tensor"
assert preprocess.data_source_of_name(DefaultDataSources.TENSORS)() == "tensors"
assert preprocess.data_source_of_name("tensors")() == "tensors"
assert preprocess.data_source_of_name("default")() == "test"

with pytest.raises(MisconfigurationException, match="available data sources are: test, tensor"):
Expand All @@ -165,12 +165,12 @@ def test_available_data_sources():

preprocess = CustomPreprocess()

assert DefaultDataSources.TENSOR in preprocess.available_data_sources()
assert DefaultDataSources.TENSORS in preprocess.available_data_sources()
assert "test" in preprocess.available_data_sources()
assert len(preprocess.available_data_sources()) == 2

data_module = DataModule(preprocess=preprocess)

assert DefaultDataSources.TENSOR in data_module.available_data_sources()
assert DefaultDataSources.TENSORS in data_module.available_data_sources()
assert "test" in data_module.available_data_sources()
assert len(data_module.available_data_sources()) == 2
2 changes: 1 addition & 1 deletion tests/vision/segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_predict_tensor():
img = torch.rand(1, 3, 10, 20)
model = SemanticSegmentation(2)
data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess())
out = model.predict(img, data_source="tensor", data_pipeline=data_pipe)
out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
assert isinstance(out[0], torch.Tensor)
assert out[0].shape == (196, 196)

Expand Down

0 comments on commit 1380c53

Please sign in to comment.