Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[1/N] Data Sources #256

Merged
merged 84 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from 75 commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
735740e
Initial commit
ethanwharris Apr 28, 2021
be01397
POC Initial commit
ethanwharris Apr 29, 2021
214df85
Remove unused code
ethanwharris Apr 29, 2021
8f93bfb
Some fixes
ethanwharris Apr 30, 2021
e8ee4c0
Simplify data source
ethanwharris Apr 30, 2021
653057d
Expand preprocess
ethanwharris Apr 30, 2021
0184332
Fixes
ethanwharris Apr 30, 2021
5172a06
Fixes
ethanwharris Apr 30, 2021
5c3f597
Cleaning
ethanwharris Apr 30, 2021
44d70e1
Fixes
ethanwharris Apr 30, 2021
08657ea
Remove un-needed code
ethanwharris May 4, 2021
73be792
Remove sequence data source
ethanwharris May 4, 2021
3381840
Simplify data source
ethanwharris May 4, 2021
e01987d
Fix FilesDataSource
ethanwharris May 4, 2021
e385dfa
Minor fix
ethanwharris May 4, 2021
dc90754
Add numpy and tesnor data sources
ethanwharris May 4, 2021
c437043
Fixes
ethanwharris May 4, 2021
b32ee34
Onboard object detection
ethanwharris May 5, 2021
bfd320d
update
tchaton May 5, 2021
7e050be
Add text classification
ethanwharris May 5, 2021
34c41d4
Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh…
ethanwharris May 5, 2021
6e0f69d
Small update
ethanwharris May 5, 2021
a2082bc
Add tabular
ethanwharris May 5, 2021
fd07644
Fixes
ethanwharris May 5, 2021
19e966d
Fixes
ethanwharris May 5, 2021
3b7ab0e
Add summarization example
ethanwharris May 6, 2021
d9c00c5
Add translation
ethanwharris May 6, 2021
2da3339
Merge branch 'master' into feature/data_sources
ethanwharris May 6, 2021
d5b8c4a
assert empty data_source in datapipeline creation
edgarriba May 6, 2021
dd35da6
add more assertions for test_classification_task_predict_folder_path
edgarriba May 6, 2021
f2c3f20
Add video
ethanwharris May 6, 2021
e186926
Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh…
ethanwharris May 6, 2021
83024bb
add smoke tests for autodataset
edgarriba May 6, 2021
8309080
improve autodataset test
edgarriba May 6, 2021
f1c44a1
Fix some tests
ethanwharris May 6, 2021
47e8f3f
Fix a test
ethanwharris May 6, 2021
f3a238e
Fixes
ethanwharris May 6, 2021
eb5cfdd
add tests for base and iterable
edgarriba May 6, 2021
a997b9d
add todo with detected error in callbacks test
edgarriba May 6, 2021
b18f0fd
fix test_data_pipeline_init_and_assignement
edgarriba May 6, 2021
bda0a12
fix test_data_pipeline_is_overriden_and_resolve_function_hierarchy
edgarriba May 6, 2021
e4a4f8a
Fix some tests
ethanwharris May 6, 2021
f5f000f
Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh…
ethanwharris May 6, 2021
464fffe
Fix some tests
ethanwharris May 6, 2021
e7d6b66
Fix some tests
ethanwharris May 6, 2021
cc57f86
Fix a test
ethanwharris May 6, 2021
3a63083
Fixes
ethanwharris May 6, 2021
3489953
Fixes
ethanwharris May 6, 2021
1ccf7ab
Fixes
ethanwharris May 6, 2021
64aff9e
Fixes
ethanwharris May 6, 2021
33506b3
Fixes
ethanwharris May 6, 2021
1f50432
deprecate csv test for image classification
edgarriba May 6, 2021
6b587fe
Fix video
ethanwharris May 6, 2021
2794a98
Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh…
ethanwharris May 6, 2021
4215a47
fix test_from_filepaths_splits
edgarriba May 6, 2021
0256c04
Fixes
ethanwharris May 6, 2021
9806b85
Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh…
ethanwharris May 6, 2021
1d5c41b
Fixes
ethanwharris May 6, 2021
8064c65
Fixes
ethanwharris May 6, 2021
a32560c
Fixes
ethanwharris May 6, 2021
4d34d94
Fixes
ethanwharris May 6, 2021
c85a8db
Fixes
ethanwharris May 6, 2021
c93a649
Fixes
ethanwharris May 6, 2021
02fd77b
Fixes
ethanwharris May 6, 2021
3d780fa
Fix docs build
ethanwharris May 6, 2021
704f558
Fix docs build
ethanwharris May 7, 2021
edfc38e
Fix examples
ethanwharris May 7, 2021
4679cb5
Fixes
ethanwharris May 7, 2021
05a1e98
Fixes
ethanwharris May 7, 2021
46b6a4f
Bump huggingface minimal
ethanwharris May 7, 2021
5b2013e
debugging
ethanwharris May 7, 2021
75f3469
debugging
ethanwharris May 7, 2021
950b13f
Fixes
ethanwharris May 7, 2021
f47208c
Fixes
ethanwharris May 7, 2021
db0c991
Respond to comments
ethanwharris May 7, 2021
db1cdf1
feedback
ethanwharris May 7, 2021
88cbc65
Updates
ethanwharris May 7, 2021
4ee1dd4
Fixes
ethanwharris May 7, 2021
ce3fcf2
Fixes
ethanwharris May 7, 2021
ed22b10
revert
ethanwharris May 7, 2021
f453d03
Updates
ethanwharris May 7, 2021
1ae8c56
Fixes
ethanwharris May 7, 2021
1088022
Fixes
ethanwharris May 7, 2021
9032be4
Fixes
ethanwharris May 7, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Here are common terms you need to be familiar with:
- The :class:`~flash.data.data_module.DataModule` contains the dataset, transforms and dataloaders.
* - :class:`~flash.data.data_pipeline.DataPipeline`
- The :class:`~flash.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects.
* - :class:`~flash.data.data_source.DataSource`
- The :class:`~flash.data.data_source.DataSource` provides a hook-based API for creating data sets.
* - :class:`~flash.data.process.Preprocess`
- The :class:`~flash.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic.
The :class:`~flash.data.process.Preprocess` provides multiple hooks such as :meth:`~flash.data.process.Preprocess.load_data`
Expand Down Expand Up @@ -275,6 +277,17 @@ Example::
API reference
*************

.. _data_source:

DataSource
__________

.. autoclass:: flash.data.data_source.DataSource
:members:


----------

.. _preprocess:

Preprocess
Expand Down Expand Up @@ -325,7 +338,6 @@ __________

.. autoclass:: flash.data.data_module.DataModule
:members:
from_load_data_inputs,
train_dataset,
val_dataset,
test_dataset,
Expand Down
4 changes: 0 additions & 4 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,4 @@ ImageClassificationData

.. autoclass:: flash.vision.ImageClassificationData

.. automethod:: flash.vision.ImageClassificationData.from_filepaths

.. automethod:: flash.vision.ImageClassificationData.from_folders

.. autoclass:: flash.vision.ImageClassificationPreprocess
2 changes: 1 addition & 1 deletion docs/source/reference/tabular_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,4 @@ TabularData

.. automethod:: flash.tabular.TabularData.from_csv

.. automethod:: flash.tabular.TabularData.from_df
.. automethod:: flash.tabular.TabularData.from_data_frame
2 changes: 0 additions & 2 deletions docs/source/reference/video_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,3 @@ VideoClassificationData
-----------------------

.. autoclass:: flash.video.VideoClassificationData

.. automethod:: flash.video.VideoClassificationData.from_paths
21 changes: 8 additions & 13 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, List, Mapping, Optional, Sequence, Union

import torch
Expand All @@ -20,20 +19,15 @@
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.model import Task
from flash.data.process import ProcessState, Serializer
from flash.data.data_source import LabelsState
from flash.data.process import Serializer


def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Calls BCE with logits and cast the target one_hot (y) encoding to floating point precision."""
return F.binary_cross_entropy_with_logits(x, y.float())


@dataclass(unsafe_hash=True, frozen=True)
class ClassificationState(ProcessState):

labels: Optional[List[str]]


class ClassificationTask(Task):

def __init__(
Expand Down Expand Up @@ -130,7 +124,7 @@ class Labels(Classes):
Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.ClassificationState`.
provided, will attempt to get them from the :class:`.LabelsState`.
multi_label: If true, treats outputs as multi label logits.
Expand All @@ -141,13 +135,16 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
super().__init__(multi_label=multi_label, threshold=threshold)
self._labels = labels

if labels is not None:
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
labels = None

if self._labels is not None:
labels = self._labels
else:
state = self.get_state(ClassificationState)
state = self.get_state(LabelsState)
if state is not None:
labels = state.labels

Expand All @@ -158,7 +155,5 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
return [labels[cls] for cls in classes]
return labels[classes]
else:
rank_zero_warn(
"No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning
)
rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning)
return classes
45 changes: 36 additions & 9 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from importlib import import_module
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
import inspect

import torch
import torchmetrics
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
Expand All @@ -29,7 +29,8 @@
from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline
from flash.data.data_pipeline import DataPipeline, DataPipelineState
from flash.data.data_source import DataSource, DefaultDataSources
from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping


Expand Down Expand Up @@ -103,6 +104,9 @@ def __init__(
self._postprocess: Optional[Postprocess] = postprocess
self._serializer: Optional[Serializer] = None

# TODO: create enum values to define what are the exact states
self._data_pipeline_state: Optional[DataPipelineState] = None

# Explicitly set the serializer to call the setter
self.serializer = serializer

Expand Down Expand Up @@ -154,6 +158,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
def predict(
self,
x: Any,
data_source: str = "default",
data_pipeline: Optional[DataPipeline] = None,
) -> Any:
"""
Expand All @@ -169,9 +174,9 @@ def predict(
"""
running_stage = RunningStage.PREDICTING

data_pipeline = self.build_data_pipeline(data_pipeline)
data_pipeline = self.build_data_pipeline(data_source, data_pipeline)

x = [x for x in data_pipeline._generate_auto_dataset(x, running_stage)]
x = [x for x in data_pipeline._data_source.generate_dataset(x, running_stage)]
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
x = data_pipeline.worker_preprocessor(running_stage)(x)
# switch to self.device when #7188 merge in Lightning
x = self.transfer_batch_to_device(x, next(self.parameters()).device)
Expand Down Expand Up @@ -252,7 +257,11 @@ def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]):
serializer = SerializerMapping(serializer)
self._serializer = serializer

def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
def build_data_pipeline(
self,
data_source: Optional[str] = None,
data_pipeline: Optional[DataPipeline] = None,
) -> Optional[DataPipeline]:
"""Build a :class:`.DataPipeline` incorporating available
:class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
objects. These will be overridden in the following resolution order (lowest priority first):
Expand All @@ -269,20 +278,26 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O
Returns:
The fully resolved :class:`.DataPipeline`.
"""
preprocess, postprocess, serializer = None, None, None
old_data_source, preprocess, postprocess, serializer = None, None, None, None

# Datamodule
if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
old_data_source = getattr(self.datamodule.data_pipeline, '_data_source', None)
preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None)
postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None)
serializer = getattr(self.datamodule.data_pipeline, '_serializer', None)

elif self.trainer is not None and hasattr(
self.trainer, 'datamodule'
) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None:
old_data_source = getattr(self.trainer.datamodule.data_pipeline, '_data_source', None)
preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None)
postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None)
serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None)
else:
# TODO: we should log with low severity level that we use defaults to create
# `preprocess`, `postprocess` and `serializer`.
pass

# Defaults / task attributes
preprocess, postprocess, serializer = Task._resolve(
Expand All @@ -305,8 +320,16 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O
getattr(data_pipeline, '_serializer', None),
)

data_pipeline = DataPipeline(preprocess, postprocess, serializer)
data_pipeline.initialize()
data_source = data_source or old_data_source

if isinstance(data_source, str):
if preprocess is None:
data_source = DataSource() # TODO: warn the user that we are not using the specified data source
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we have some kind of registry (similar to what we do with the backbones) and look the source up there if no preprocess is given (in fact this should also be the default behaviour of the preprocess then)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some ideas for a data sources registry. Not sure it makes sense to just map strings to data sources as most data sources only work with particular preprocesses.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but the registry would most likely be task specific as well (similar to backbones).

else:
data_source = preprocess.data_source_of_name(data_source)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved

data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer)
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
return data_pipeline

@property
Expand Down Expand Up @@ -376,12 +399,16 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# https://pytorch.org/docs/stable/notes/serialization.html
if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
checkpoint['data_pipeline'] = self.data_pipeline
if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint:
checkpoint['_data_pipeline_state'] = self._data_pipeline_state
super().on_save_checkpoint(checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_load_checkpoint(checkpoint)
if 'data_pipeline' in checkpoint:
self.data_pipeline = checkpoint['data_pipeline']
if '_data_pipeline_state' in checkpoint:
self._data_pipeline_state = checkpoint['_data_pipeline_state']

@classmethod
def available_backbones(cls) -> List[str]:
Expand Down
Loading